//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a printer that converts from our internal representation
// of machine-dependent LLVM code to NVPTX assembly language.
//
//===----------------------------------------------------------------------===//

#include "NVPTXAsmPrinter.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "MCTargetDesc/NVPTXInstPrinter.h"
#include "MCTargetDesc/NVPTXMCAsmInfo.h"
#include "MCTargetDesc/NVPTXTargetStreamer.h"
#include "NVPTX.h"
#include "NVPTXMCExpr.h"
#include "NVPTXMachineFunctionInfo.h"
#include "NVPTXRegisterInfo.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "TargetInfo/NVPTXTargetInfo.h"
#include "cl_common_defines.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrDesc.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>

using namespace llvm;

static cl::opt<bool>
LowerCtorDtor("nvptx-lower-global-ctor-dtor", 96 cl::desc("Lower GPU ctor / dtors to globals on the device."), 97 cl::init(false), cl::Hidden); 98 99 #define DEPOTNAME "__local_depot" 100 101 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V 102 /// depends. 103 static void 104 DiscoverDependentGlobals(const Value *V, 105 DenseSet<const GlobalVariable *> &Globals) { 106 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) 107 Globals.insert(GV); 108 else { 109 if (const User *U = dyn_cast<User>(V)) { 110 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) { 111 DiscoverDependentGlobals(U->getOperand(i), Globals); 112 } 113 } 114 } 115 } 116 117 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable 118 /// instances to be emitted, but only after any dependents have been added 119 /// first.s 120 static void 121 VisitGlobalVariableForEmission(const GlobalVariable *GV, 122 SmallVectorImpl<const GlobalVariable *> &Order, 123 DenseSet<const GlobalVariable *> &Visited, 124 DenseSet<const GlobalVariable *> &Visiting) { 125 // Have we already visited this one? 126 if (Visited.count(GV)) 127 return; 128 129 // Do we have a circular dependency? 130 if (!Visiting.insert(GV).second) 131 report_fatal_error("Circular dependency found in global variable set"); 132 133 // Make sure we visit all dependents first 134 DenseSet<const GlobalVariable *> Others; 135 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i) 136 DiscoverDependentGlobals(GV->getOperand(i), Others); 137 138 for (const GlobalVariable *GV : Others) 139 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting); 140 141 // Now we can visit ourself 142 Order.push_back(GV); 143 Visited.insert(GV); 144 Visiting.erase(GV); 145 } 146 147 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) { 148 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(), 149 getSubtargetInfo().getFeatureBits()); 150 151 MCInst Inst; 152 lowerToMCInst(MI, Inst); 153 EmitToStreamer(*OutStreamer, Inst); 154 } 155 156 // Handle symbol backtracking for targets that do not support image handles 157 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI, 158 unsigned OpNo, MCOperand &MCOp) { 159 const MachineOperand &MO = MI->getOperand(OpNo); 160 const MCInstrDesc &MCID = MI->getDesc(); 161 162 if (MCID.TSFlags & NVPTXII::IsTexFlag) { 163 // This is a texture fetch, so operand 4 is a texref and operand 5 is 164 // a samplerref 165 if (OpNo == 4 && MO.isImm()) { 166 lowerImageHandleSymbol(MO.getImm(), MCOp); 167 return true; 168 } 169 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) { 170 lowerImageHandleSymbol(MO.getImm(), MCOp); 171 return true; 172 } 173 174 return false; 175 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { 176 unsigned VecSize = 177 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); 178 179 // For a surface load of vector size N, the Nth operand will be the surfref 180 if (OpNo == VecSize && MO.isImm()) { 181 lowerImageHandleSymbol(MO.getImm(), MCOp); 182 return true; 183 } 184 185 return false; 186 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { 187 // This is a surface store, so operand 0 is a surfref 188 if (OpNo == 0 && MO.isImm()) { 189 lowerImageHandleSymbol(MO.getImm(), MCOp); 190 return true; 191 } 192 193 return false; 194 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { 195 // This is a query, so operand 1 is a surfref/texref 196 if (OpNo == 1 && MO.isImm()) { 197 
      lowerImageHandleSymbol(MO.getImm(), MCOp);
      return true;
    }

    return false;
  }

  return false;
}

void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
  // Ewwww
  TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
  StringRef Sym = MFI->getImageHandleSymbol(Index);
  StringRef SymName = nvTM.getStrPool().save(Sym);
  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
}

void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
  OutMI.setOpcode(MI->getOpcode());
  // Special: Do not mangle symbol operand of CALL_PROTOTYPE
  if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
    const MachineOperand &MO = MI->getOperand(0);
    OutMI.addOperand(GetSymbolRef(
        OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
    return;
  }

  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
    const MachineOperand &MO = MI->getOperand(i);

    MCOperand MCOp;
    if (lowerImageHandleOperand(MI, i, MCOp)) {
      OutMI.addOperand(MCOp);
      continue;
    }

    if (lowerOperand(MO, MCOp))
      OutMI.addOperand(MCOp);
  }
}

bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
                                   MCOperand &MCOp) {
  switch (MO.getType()) {
  default: llvm_unreachable("unknown operand type");
  case MachineOperand::MO_Register:
    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
    break;
  case MachineOperand::MO_Immediate:
    MCOp = MCOperand::createImm(MO.getImm());
    break;
  case MachineOperand::MO_MachineBasicBlock:
    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
        MO.getMBB()->getSymbol(), OutContext));
    break;
  case MachineOperand::MO_ExternalSymbol:
    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
    break;
  case MachineOperand::MO_GlobalAddress:
    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
    break;
  case MachineOperand::MO_FPImmediate: {
    const ConstantFP *Cnt = MO.getFPImm();
    const APFloat &Val = Cnt->getValueAPF();

    switch (Cnt->getType()->getTypeID()) {
    default: report_fatal_error("Unsupported FP type"); break;
    case Type::HalfTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
      break;
    case Type::BFloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
      break;
    case Type::FloatTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
      break;
    case Type::DoubleTyID:
      MCOp = MCOperand::createExpr(
          NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
      break;
    }
    break;
  }
  }
  return true;
}

unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
  if (Register::isVirtualRegister(Reg)) {
    const TargetRegisterClass *RC = MRI->getRegClass(Reg);

    DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
    unsigned RegNum = RegMap[Reg];

    // Encode the register class in the upper 4 bits
    // Must be kept in sync with NVPTXInstPrinter::printRegName
    unsigned Ret = 0;
    if (RC == &NVPTX::Int1RegsRegClass) {
      Ret = (1 << 28);
    } else if (RC == &NVPTX::Int16RegsRegClass) {
      Ret = (2 << 28);
    } else if (RC == &NVPTX::Int32RegsRegClass) {
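      // A hedged worked example: Int32 vreg number 5 encodes as
      // (3 << 28) | 5 = 0x30000005, and NVPTXInstPrinter::printRegName
      // decodes the top nibble back to the class prefix when printing.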
      Ret = (3 << 28);
    } else if (RC == &NVPTX::Int64RegsRegClass) {
      Ret = (4 << 28);
    } else if (RC == &NVPTX::Float32RegsRegClass) {
      Ret = (5 << 28);
    } else if (RC == &NVPTX::Float64RegsRegClass) {
      Ret = (6 << 28);
    } else if (RC == &NVPTX::Int128RegsRegClass) {
      Ret = (7 << 28);
    } else {
      report_fatal_error("Bad register class");
    }

    // Insert the vreg number
    Ret |= (RegNum & 0x0FFFFFFF);
    return Ret;
  } else {
    // Some special-use registers are actually physical registers.
    // Encode this as the register class ID of 0 and the real register ID.
    return Reg & 0x0FFFFFFF;
  }
}

MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
  const MCExpr *Expr;
  Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
                                 OutContext);
  return MCOperand::createExpr(Expr);
}

static bool ShouldPassAsArray(Type *Ty) {
  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
         Ty->isHalfTy() || Ty->isBFloatTy();
}

void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());

  Type *Ty = F->getReturnType();

  bool isABI = (STI.getSmVersion() >= 20);

  if (Ty->getTypeID() == Type::VoidTyID)
    return;
  O << " (";

  if (isABI) {
    if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
        !ShouldPassAsArray(Ty)) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
        size = ITy->getBitWidth();
      } else {
        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
        size = Ty->getPrimitiveSizeInBits();
      }
      size = promoteScalarArgumentSize(size);
      O << ".param .b" << size << " func_retval0";
    } else if (isa<PointerType>(Ty)) {
      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
        << " func_retval0";
    } else if (ShouldPassAsArray(Ty)) {
      unsigned totalsz = DL.getTypeAllocSize(Ty);
      Align RetAlignment = TLI->getFunctionArgumentAlignment(
          F, Ty, AttributeList::ReturnIndex, DL);
      O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
        << totalsz << "]";
    } else
      llvm_unreachable("Unknown return type");
  } else {
    SmallVector<EVT, 16> vtparts;
    ComputeValueVTs(*TLI, DL, Ty, vtparts);
    unsigned idx = 0;
    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = vtparts[i];
      if (vtparts[i].isVector()) {
        elems = vtparts[i].getVectorNumElements();
        elemtype = vtparts[i].getVectorElementType();
      }

      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger())
          sz = promoteScalarArgumentSize(sz);
        O << ".reg .b" << sz << " func_retval" << idx;
        if (j < je - 1)
          O << ", ";
        ++idx;
      }
      if (i < e - 1)
        O << ", ";
    }
  }
  O << ") ";
}

void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
                                        raw_ostream &O) {
  const Function &F = MF.getFunction();
  printReturnValStr(&F, O);
}

// Return true if MBB is the header of a loop marked with
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
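//
// A hedged illustrative example of the IR this matches (names assumed):
//   br i1 %cond, label %header, label %exit, !llvm.loop !0
//   !0 = distinct !{!0, !1}
//   !1 = !{!"llvm.loop.unroll.disable"}
// The header of such a loop gets a '.pragma "nounroll";' in the emitted PTX.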
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
    const MachineBasicBlock &MBB) const {
  MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
  // We insert .pragma "nounroll" only to the loop header.
  if (!LI.isLoopHeader(&MBB))
    return false;

  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
  // we iterate through each back edge of the loop with header MBB, and check
  // whether its metadata contains llvm.loop.unroll.disable.
  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
    if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
      // Edges from other loops to MBB are not back edges.
      continue;
    }
    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
      if (MDNode *LoopID =
              PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
        if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
          return true;
        if (MDNode *UnrollCountMD =
                GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
          if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
                  ->isOne())
            return true;
        }
      }
    }
  }
  return false;
}

void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
  AsmPrinter::emitBasicBlockStart(MBB);
  if (isLoopHeaderOfNoUnroll(MBB))
    OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
}

void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  if (!GlobalsEmitted) {
    emitGlobals(*MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(*MF, O);
  }

  CurrentFnSym->print(O, MAI);

  emitFunctionParamList(F, O);
  O << "\n";

  if (isKernelFunction(*F))
    emitKernelFunctionDirectives(*F, O);

  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(O.str());

  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  encodeDebugInfoRegisterNumbers(*MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
    assert(SP->getUnit());
    if (!SP->getUnit()->isDebugDirectivesOnly())
      emitInitialRawDwarfLocDirective(*MF);
  }
}

bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
  bool Result = AsmPrinter::runOnMachineFunction(F);
  // Emit closing brace for the body of function F. It must be emitted here
  // because we need to emit additional debug labels/data after the last basic
  // block, and there is no later callback that runs once emission of the
  // function body has finished.
  OutStreamer->emitRawText(StringRef("}\n"));
  return Result;
}

void NVPTXAsmPrinter::emitFunctionBodyStart() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);
  emitDemotedVars(&MF->getFunction(), O);
  OutStreamer->emitRawText(O.str());
}

void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  VRegMapping.clear();
}

const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
  SmallString<128> Str;
  raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
  return OutContext.getOrCreateSymbol(Str);
}

void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
  Register RegNo = MI->getOperand(0).getReg();
  if (RegNo.isVirtual()) {
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            getVirtualRegisterName(RegNo));
  } else {
    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
    OutStreamer->AddComment(Twine("implicit-def: ") +
                            STI.getRegisterInfo()->getName(RegNo));
  }
  OutStreamer->addBlankLine();
}

void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of reqntid* is specified, don't output the reqntid directive.
  std::optional<unsigned> Reqntidx = getReqNTIDx(F);
  std::optional<unsigned> Reqntidy = getReqNTIDy(F);
  std::optional<unsigned> Reqntidz = getReqNTIDz(F);

  if (Reqntidx || Reqntidy || Reqntidz)
    O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
      << ", " << Reqntidz.value_or(1) << "\n";

  // If the NVVM IR has some of maxntid* specified, then output
  // the maxntid directive, and set the unspecified ones to 1.
  // If none of maxntid* is specified, don't output the maxntid directive.
  std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
  std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
  std::optional<unsigned> Maxntidz = getMaxNTIDz(F);

  if (Maxntidx || Maxntidy || Maxntidz)
    O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
      << ", " << Maxntidz.value_or(1) << "\n";

  if (const auto Mincta = getMinCTASm(F))
    O << ".minnctapersm " << *Mincta << "\n";

  if (const auto Maxnreg = getMaxNReg(F))
    O << ".maxnreg " << *Maxnreg << "\n";

  // The .maxclusterrank directive requires SM_90 or higher; make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
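  // A hedged example of the cluster directives the sm_90+ block below can
  // emit (values assumed):
  //   .explicitcluster
  //   .reqnctapercluster 2, 1, 1
  //   .maxclusterrank 4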
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI =
      static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  if (STI->getSmVersion() >= 90) {
    std::optional<unsigned> ClusterX = getClusterDimx(F);
    std::optional<unsigned> ClusterY = getClusterDimy(F);
    std::optional<unsigned> ClusterZ = getClusterDimz(F);

    if (ClusterX || ClusterY || ClusterZ) {
      O << ".explicitcluster\n";
      if (ClusterX.value_or(1) != 0) {
        assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
               "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
               "should be non-zero as well");

        O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
          << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
      } else {
        assert(!ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
               "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
               "should be 0 as well");
      }
    }
    if (const auto Maxclusterrank = getMaxClusterRank(F))
      O << ".maxclusterrank " << *Maxclusterrank << "\n";
  }
}

std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
  const TargetRegisterClass *RC = MRI->getRegClass(Reg);

  std::string Name;
  raw_string_ostream NameStr(Name);

  VRegRCMap::const_iterator I = VRegMapping.find(RC);
  assert(I != VRegMapping.end() && "Bad register class");
  const DenseMap<unsigned, unsigned> &RegMap = I->second;

  VRegMap::const_iterator VI = RegMap.find(Reg);
  assert(VI != RegMap.end() && "Bad virtual register");
  unsigned MappedVR = VI->second;

  NameStr << getNVPTXRegClassStr(RC) << MappedVR;

  return Name;
}

void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
                                          raw_ostream &O) {
  O << getVirtualRegisterName(vr);
}

void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
                                           raw_ostream &O) {
  const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
  if (!F || isKernelFunction(*F) || F->isDeclaration())
    report_fatal_error(
        "NVPTX aliasee must be a non-kernel function definition");

  if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
      GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
    report_fatal_error("NVPTX aliasee must not be '.weak'");

  emitDeclarationWithName(F, getSymbol(GA), O);
}

void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
  emitDeclarationWithName(F, getSymbol(F), O);
}

void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
                                              raw_ostream &O) {
  emitLinkageDirective(F, O);
  if (isKernelFunction(*F))
    O << ".entry ";
  else
    O << ".func ";
  printReturnValStr(F, O);
  S->print(O, MAI);
  O << "\n";
  emitFunctionParamList(F, O);
  O << "\n";
  if (shouldEmitPTXNoReturn(F, TM))
    O << ".noreturn";
  O << ";\n";
}

static bool usedInGlobalVarDef(const Constant *C) {
  if (!C)
    return false;

  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
    return GV->getName() != "llvm.used";
  }

  for (const User *U : C->users())
    if (const Constant *C = dyn_cast<Constant>(U))
      if (usedInGlobalVarDef(C))
        return true;

  return false;
}

static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
    if (othergv->getName() == "llvm.used")
      return true;
  }

  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
    if (instr->getParent() && instr->getParent()->getParent()) {
      const Function *curFunc = instr->getParent()->getParent();
      if (oneFunc && (curFunc != oneFunc))
        return false;
      oneFunc = curFunc;
      return true;
    } else
      return false;
  }

  for (const User *UU : U->users())
    if (!usedInOneFunc(UU, oneFunc))
      return false;

  return true;
}

/* Find out if a global variable can be demoted to local scope.
 * Currently, this is valid for CUDA shared variables, which have local
 * scope and global lifetime. So the conditions to check are :
 * 1. Is the global variable in shared address space?
 * 2. Does it have local linkage?
 * 3. Is the global variable referenced only in one function?
 */
static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
  if (!gv->hasLocalLinkage())
    return false;
  PointerType *Pty = gv->getType();
  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
    return false;

  const Function *oneFunc = nullptr;

  bool flag = usedInOneFunc(gv, oneFunc);
  if (!flag)
    return false;
  if (!oneFunc)
    return false;
  f = oneFunc;
  return true;
}

static bool useFuncSeen(const Constant *C,
                        DenseMap<const Function *, bool> &seenMap) {
  for (const User *U : C->users()) {
    if (const Constant *cu = dyn_cast<Constant>(U)) {
      if (useFuncSeen(cu, seenMap))
        return true;
    } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
      const BasicBlock *bb = I->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;
      if (seenMap.contains(caller))
        return true;
    }
  }
  return false;
}

void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  DenseMap<const Function *, bool> seenMap;
  for (const Function &F : M) {
    if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
      emitDeclaration(&F, O);
      continue;
    }

    if (F.isDeclaration()) {
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(&F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(&F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, seenMap)) {
          emitDeclaration(&F, O);
          break;
        }
      }

      if (!isa<Instruction>(U))
        continue;
      const Instruction *instr = cast<Instruction>(U);
      const BasicBlock *bb = instr->getParent();
      if (!bb)
        continue;
      const Function *caller = bb->getParent();
      if (!caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee, so print out
      // a declaration for the callee.
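      // A hedged sketch of the emitted forward declaration (signature
      // assumed); ptxas needs it before the caller's body references the
      // callee. It looks roughly like:
      //   .func (.param .b32 func_retval0) callee
      //   (
      //       .param .b32 callee_param_0
      //   )
      //   ;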
      if (seenMap.contains(caller)) {
        emitDeclaration(&F, O);
        break;
      }
    }
    seenMap[&F] = true;
  }
  for (const GlobalAlias &GA : M.aliases())
    emitAliasDeclaration(&GA, O);
}

static bool isEmptyXXStructor(GlobalVariable *GV) {
  if (!GV) return true;
  const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
  if (!InitList) return true; // Not an array; we don't know how to parse.
  return InitList->getNumOperands() == 0;
}

void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
  // Construct a default subtarget off of the TargetMachine defaults. The
  // rest of NVPTX isn't friendly to change subtargets per function and
  // so the default TargetMachine will have all of the options.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto *STI =
      static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  SmallString<128> Str1;
  raw_svector_ostream OS1(Str1);

  // Emit header before any dwarf directives are emitted below.
  emitHeader(M, OS1, *STI);
  OutStreamer->emitRawText(OS1.str());
}

bool NVPTXAsmPrinter::doInitialization(Module &M) {
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
    report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");

  // OpenMP supports NVPTX global constructors and destructors.
  bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;

  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
      !LowerCtorDtor && !IsOpenMP) {
    report_fatal_error(
        "Module has a nontrivial global ctor, which NVPTX does not support.");
    return true; // error
  }
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
      !LowerCtorDtor && !IsOpenMP) {
    report_fatal_error(
        "Module has a nontrivial global dtor, which NVPTX does not support.");
    return true; // error
  }

  // We need to call the parent's one explicitly.
  bool Result = AsmPrinter::doInitialization(M);

  GlobalsEmitted = false;

  return Result;
}

void NVPTXAsmPrinter::emitGlobals(const Module &M) {
  SmallString<128> Str2;
  raw_svector_ostream OS2(Str2);

  emitDeclarations(M, OS2);

  // As ptxas does not support forward references of globals, we need to first
  // sort the list of module-level globals in def-use order. We visit each
  // global variable in order, and ensure that we emit it *after* its dependent
  // globals. We use a little extra memory maintaining both a set and a list to
  // have fast searches while maintaining a strict ordering.
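  // A hedged example of the constraint (names assumed): given
  //   @a = global i32 0
  //   @b = global ptr @a
  // @a must be printed before @b, because @b's initializer refers to @a and
  // ptxas rejects forward references.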
  SmallVector<const GlobalVariable *, 8> Globals;
  DenseSet<const GlobalVariable *> GVVisited;
  DenseSet<const GlobalVariable *> GVVisiting;

  // Visit each global variable, in order
  for (const GlobalVariable &I : M.globals())
    VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);

  assert(GVVisited.size() == M.global_size() && "Missed a global variable");
  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  // Print out module-level global variables in proper order
  for (const GlobalVariable *GV : Globals)
    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);

  OS2 << '\n';

  OutStreamer->emitRawText(OS2.str());
}

void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
  SmallString<128> Str;
  raw_svector_ostream OS(Str);

  MCSymbol *Name = getSymbol(&GA);

  OS << ".alias " << Name->getName() << ", "
     << GA.getAliaseeObject()->getName() << ";\n";

  OutStreamer->emitRawText(OS.str());
}

void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                 const NVPTXSubtarget &STI) {
  O << "//\n";
  O << "// Generated by LLVM NVPTX Back-End\n";
  O << "//\n";
  O << "\n";

  unsigned PTXVersion = STI.getPTXVersion();
  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";

  O << ".target ";
  O << STI.getTargetName();

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  if (NTM.getDrvInterface() == NVPTX::NVCL)
    O << ", texmode_independent";

  bool HasFullDebugInfo = false;
  for (DICompileUnit *CU : M.debug_compile_units()) {
    switch (CU->getEmissionKind()) {
    case DICompileUnit::NoDebug:
    case DICompileUnit::DebugDirectivesOnly:
      break;
    case DICompileUnit::LineTablesOnly:
    case DICompileUnit::FullDebug:
      HasFullDebugInfo = true;
      break;
    }
    if (HasFullDebugInfo)
      break;
  }
  if (HasFullDebugInfo)
    O << ", debug";

  O << "\n";

  O << ".address_size ";
  if (NTM.is64Bit())
    O << "64";
  else
    O << "32";
  O << "\n";

  O << "\n";
}

bool NVPTXAsmPrinter::doFinalization(Module &M) {
  // If we did not emit any functions, then the global declarations have not
  // yet been emitted.
  if (!GlobalsEmitted) {
    emitGlobals(M);
    GlobalsEmitted = true;
  }

  // call doFinalization
  bool ret = AsmPrinter::doFinalization(M);

  clearAnnotationCache(&M);

  auto *TS =
      static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
  // Close the last emitted section
  if (hasDebugInfo()) {
    TS->closeLastSection();
    // Emit empty .debug_macinfo section for better support of the empty files.
    OutStreamer->emitRawText("\t.section\t.debug_macinfo\t{\t}");
  }

  // Output last DWARF .file directives, if any.
  TS->outputDwarfFileDirectives();

  return ret;
}

// This function emits appropriate linkage directives for
// functions and global variables.
//
// extern function declaration            -> .extern
// extern function definition             -> .visible
// external global variable with init     -> .visible
// external without init                  -> .extern
// appending                              -> not allowed, assert.
// for any linkage other than
// internal, private, linker_private,
// linker_private_weak, linker_private_weak_def_auto,
// we emit                                -> .weak.

void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                           raw_ostream &O) {
  if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
    if (V->hasExternalLinkage()) {
      if (isa<GlobalVariable>(V)) {
        const GlobalVariable *GVar = cast<GlobalVariable>(V);
        if (GVar) {
          if (GVar->hasInitializer())
            O << ".visible ";
          else
            O << ".extern ";
        }
      } else if (V->isDeclaration())
        O << ".extern ";
      else
        O << ".visible ";
    } else if (V->hasAppendingLinkage()) {
      std::string msg;
      msg.append("Error: ");
      msg.append("Symbol ");
      if (V->hasName())
        msg.append(std::string(V->getName()));
      msg.append(" has unsupported appending linkage type");
      llvm_unreachable(msg.c_str());
    } else if (!V->hasInternalLinkage() &&
               !V->hasPrivateLinkage()) {
      O << ".weak ";
    }
  }
}

void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool processDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip metadata
  if (GVar->hasSection()) {
    if (GVar->getSection() == "llvm.metadata")
      return;
  }

  // Skip LLVM intrinsic global variables
  if (GVar->getName().starts_with("llvm.") ||
      GVar->getName().starts_with("nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
             GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
    O << ".common ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  if (isTexture(*GVar)) {
    O << ".global .texref " << getTextureName(*GVar) << ";\n";
    return;
  }

  if (isSurface(*GVar)) {
    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
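    // A hedged example of the emitted declaration (name and size assumed):
    //   .extern .shared .align 4 .b8 buf[256];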
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (isSampler(*GVar)) {
    O << ".global .samplerref " << getSamplerName(*GVar);

    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Initializer);
    if (CI) {
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  if (GVar->hasPrivateLinkage()) {
    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
      return;
    if (GVar->use_empty())
      return;
  }

  const Function *demotedFunc = nullptr;
  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    localDecls[demotedFunc].push_back(GVar);
    return;
  }

  O << ".";
  emitPTXAddressSpace(GVar->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }

  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(ETy, false);
    O << " ";
    getSymbol(GVar)->print(O, MAI);

    // PTX allows variable initialization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
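        // A hedged example of the emitted definition (name and value
        // assumed):
        //   .global .align 4 .u32 answer = 42;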
        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
          O << " = ";
          printScalarConstant(Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(GVar->getInitializer())) {
          report_fatal_error("initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(GVar->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    uint64_t ElementSize = 0;

    // Although PTX has direct support for struct and array types, and LLVM IR
    // is very similar to PTX in this respect, LLVM CodeGen does not lower
    // these high-level field accesses for this target. Structs, arrays and
    // vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID:
      ElementSize = DL.getTypeStoreSize(ETy);
      // PTX allows variable initialization only for constant and
      // global state spaces.
      if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(Initializer, &aggBuffer);
          if (aggBuffer.numSymbols()) {
            unsigned int ptrSize = MAI->getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(O);
              O << "}";
            } else {
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GVar)->print(O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(O);
              O << "}";
            }
          } else {
            O << " .b8 ";
            getSymbol(GVar)->print(O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(O);
            O << "}";
          }
        } else {
          O << " .b8 ";
          getSymbol(GVar)->print(O, MAI);
          if (ElementSize) {
            O << "[";
            O << ElementSize;
            O << "]";
          }
        }
      } else {
        O << " .b8 ";
        getSymbol(GVar)->print(O, MAI);
        if (ElementSize) {
          O << "[";
          O << ElementSize;
          O << "]";
        }
      }
      break;
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}

void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
  const Value *v = Symbols[nSym];
  const Value *v0 = SymbolsBeforeStripping[nSym];
  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
    MCSymbol *Name = AP.getSymbol(GVar);
    PointerType *PTy = dyn_cast<PointerType>(v0->getType());
    // Is v0 a generic pointer?
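    // A hedged example (symbol name assumed): a global referenced through an
    // addrspace(0) pointer prints as generic(foo), directing ptxas to use
    // foo's generic address rather than its state-space-relative address.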
    bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
    if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
      os << "generic(";
      Name->print(os, AP.MAI);
      os << ")";
    } else {
      Name->print(os, AP.MAI);
    }
  } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
    AP.printMCExpr(*Expr, os);
  } else
    llvm_unreachable("symbol type unknown");
}

void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  unsigned int InitializerCount = size;
  // TODO: symbols make this harder, but it would still be good to trim
  // trailing 0s for aggs with symbols as well.
  if (numSymbols() == 0)
    while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
      InitializerCount--;

  symbolPosInBuffer.push_back(InitializerCount);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < InitializerCount;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}

void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  symbolPosInBuffer.push_back(size);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  assert(nextSymbolPos % ptrSize == 0);
  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
    if (pos)
      os << ", ";
    if (pos == nextSymbolPos) {
      printSymbol(nSym, os);
      nextSymbolPos = symbolPosInBuffer[++nSym];
      assert(nextSymbolPos % ptrSize == 0);
      assert(nextSymbolPos >= pos + ptrSize);
    } else if (ptrSize == 4)
      os << support::endian::read32le(&buffer[pos]);
    else
      os << support::endian::read64le(&buffer[pos]);
  }
}

void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
  auto It = localDecls.find(f);
  if (It == localDecls.end())
    return;

  std::vector<const GlobalVariable *> &gvars = It->second;

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  for (const GlobalVariable *GV : gvars) {
    O << "\t// demoted variable\n\t";
    printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
  }
}

void
NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
                                     raw_ostream &O) const {
  switch (AddressSpace) {
  case ADDRESS_SPACE_LOCAL:
    O << "local";
    break;
  case ADDRESS_SPACE_GLOBAL:
    O << "global";
    break;
  case ADDRESS_SPACE_CONST:
    O << "const";
    break;
  case ADDRESS_SPACE_SHARED:
    O << "shared";
    break;
  default:
    report_fatal_error("Bad address space found while emitting PTX: " +
                       llvm::Twine(AddressSpace));
    break;
  }
}

std::string
NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
  switch (Ty->getTypeID()) {
  case Type::IntegerTyID: {
    unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
    if (NumBits == 1)
      return "pred";
    else if (NumBits <= 64) {
      std::string name = "u";
      return name + utostr(NumBits);
    } else {
      llvm_unreachable("Integer too large");
      break;
    }
    break;
  }
  case Type::BFloatTyID:
  case Type::HalfTyID:
    // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
    // PTX assembly.
    return "b16";
  case Type::FloatTyID:
    return "f32";
  case Type::DoubleTyID:
    return "f64";
  case Type::PointerTyID: {
    unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
    assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");

    if (PtrSize == 64)
      if (useB4PTR)
        return "b64";
      else
        return "u64";
    else if (useB4PTR)
      return "b32";
    else
      return "u32";
  }
  default:
    break;
  }
  llvm_unreachable("unexpected type");
}

void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
      report_fatal_error(
          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    }
    O << " .attribute(.managed)";
  }
  if (MaybeAlign A = GVar->getAlign())
    O << " .align " << A->value();
  else
    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();

  // Special case for i128
  if (ETy->isIntegerTy(128)) {
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[16]";
    return;
  }

  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " .";
    O << getPTXFundamentalTypeStr(ETy);
    O << " ";
    getSymbol(GVar)->print(O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct and array types, and LLVM IR
  // is very similar to PTX in this respect, LLVM CodeGen does not lower these
  // high-level field accesses for this target. Structs and arrays are lowered
  // into arrays of bytes.
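  // A hedged example (name assumed): a global of type [3 x i32] is emitted
  // here as
  //   .b8 arr[12]
  // with the state space and alignment already printed above.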
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(ETy);
    O << " .b8 ";
    getSymbol(GVar)->print(O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}

void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const AttributeList &PAL = F->getAttributes();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
  const NVPTXMachineFunctionInfo *MFI =
      MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;

  Function::const_arg_iterator I, E;
  unsigned paramIndex = 0;
  bool first = true;
  bool isKernelFunc = isKernelFunction(*F);
  bool isABI = (STI.getSmVersion() >= 20);

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
    Type *Ty = I->getType();

    if (!first)
      O << ",\n";

    first = false;

    // Handle image/sampler parameters
    if (isKernelFunc) {
      if (isSampler(*I) || isImage(*I)) {
        std::string ParamSym;
        raw_string_ostream ParamStr(ParamSym);
        ParamStr << F->getName() << "_param_" << paramIndex;
        ParamStr.flush();
        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
        if (isImage(*I)) {
          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
            if (EmitImagePtr)
              O << "\t.param .u64 .ptr .surfref ";
            else
              O << "\t.param .surfref ";
            O << TLI->getParamName(F, paramIndex);
          } else { // Default image is read_only
            if (EmitImagePtr)
              O << "\t.param .u64 .ptr .texref ";
            else
              O << "\t.param .texref ";
            O << TLI->getParamName(F, paramIndex);
          }
        } else {
          if (EmitImagePtr)
            O << "\t.param .u64 .ptr .samplerref ";
          else
            O << "\t.param .samplerref ";
          O << TLI->getParamName(F, paramIndex);
        }
        continue;
      }
    }

    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                    paramIndex](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
      return std::max(TypeAlign, ParamAlign.valueOrOne());
    };

    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
      if (ShouldPassAsArray(Ty)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a>  = optimal alignment for the element type; always multiple of
        //        PAL.getParamAlignment
        // size = typeallocsize of element type
        Align OptimalAlign = getOptimalAlignForParam(Ty);

        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
        O << TLI->getParamName(F, paramIndex);
        O << "[" << DL.getTypeAllocSize(Ty) << "]";

        continue;
      }
      // Just a scalar
      auto *PTy = dyn_cast<PointerType>(Ty);
      unsigned PTySizeInBits = 0;
      if (PTy) {
        PTySizeInBits =
            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
        assert(PTySizeInBits && "Invalid pointer size");
      }

      if (isKernelFunc) {
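        // A hedged example for a kernel pointer parameter (name assumed):
        //   .param .u64 .ptr .global .align 4 foo_param_0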
        if (PTy) {
          O << "\t.param .u" << PTySizeInBits << " .ptr";

          switch (PTy->getAddressSpace()) {
          default:
            break;
          case ADDRESS_SPACE_GLOBAL:
            O << " .global";
            break;
          case ADDRESS_SPACE_SHARED:
            O << " .shared";
            break;
          case ADDRESS_SPACE_CONST:
            O << " .const";
            break;
          case ADDRESS_SPACE_LOCAL:
            O << " .local";
            break;
          }

          O << " .align " << I->getParamAlign().valueOrOne().value();
          O << " " << TLI->getParamName(F, paramIndex);
          continue;
        }

        // non-pointer scalar to kernel func
        O << "\t.param .";
        // Special case: predicate operands become .u8 types
        if (Ty->isIntegerTy(1))
          O << "u8";
        else
          O << getPTXFundamentalTypeStr(Ty);
        O << " ";
        O << TLI->getParamName(F, paramIndex);
        continue;
      }
      // Non-kernel function, just print .param .b<size> for ABI
      // and .reg .b<size> for non-ABI
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        sz = promoteScalarArgumentSize(sz);
      } else if (PTy) {
        assert(PTySizeInBits && "Invalid pointer size");
        sz = PTySizeInBits;
      } else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << "\t.param .b" << sz << " ";
      else
        O << "\t.reg .b" << sz << " ";
      O << TLI->getParamName(F, paramIndex);
      continue;
    }

    // param has byVal attribute.
    Type *ETy = PAL.getParamByValType(paramIndex);
    assert(ETy && "Param should have byval type");

    if (isABI || isKernelFunc) {
      // Just print .param .align <a> .b8 .param[size];
      // <a>  = optimal alignment for the element type; always multiple of
      //        PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign =
          isKernelFunc
              ? getOptimalAlignForParam(ETy)
              : TLI->getFunctionByValParamAlign(
                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);

      unsigned sz = DL.getTypeAllocSize(ETy);
      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
      O << TLI->getParamName(F, paramIndex);
      O << "[" << sz << "]";
      continue;
    } else {
      // Split the ETy into constituent parts and
      // print .param .b<size> <name> for each part.
      // Further, if a part is vector, print the above for
      // each vector element.
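      // A hedged sketch (non-ABI path, names assumed): a byval {i32, i64}
      // struct is split into two parts and printed roughly as
      //   .reg .b32 foo_param_0,
      //   .reg .b64 foo_param_1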
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*TLI, DL, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger())
            sz = promoteScalarArgumentSize(sz);
          O << "\t.reg .b" << sz << " ";
          O << TLI->getParamName(F, paramIndex);
          if (j < je - 1)
            O << ",\n";
          ++paramIndex;
        }
        if (i < e - 1)
          O << ",\n";
      }
      --paramIndex;
      continue;
    }
  }

  if (F->isVarArg()) {
    if (!first)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment();
    O << " .b8 ";
    O << TLI->getParamName(F, /* vararg */ -1) << "[]";
  }

  O << "\n)";
}

void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
    const MachineFunction &MF) {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Map the global virtual register number to a register class specific
  // virtual register number starting from 1 with that class.
  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
  // unsigned numRegClasses = TRI->getNumRegClasses();

  // Emit the Fake Stack Object
  const MachineFrameInfo &MFI = MF.getFrameInfo();
  int64_t NumBytes = MFI.getStackSize();
  if (NumBytes) {
    O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
      << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
      O << "\t.reg .b64 \t%SP;\n";
      O << "\t.reg .b64 \t%SPL;\n";
    } else {
      O << "\t.reg .b32 \t%SP;\n";
      O << "\t.reg .b32 \t%SPL;\n";
    }
  }

  // Go through all virtual registers to establish the mapping between the
  // global virtual register number and the per class virtual register number.
  // We use the per class virtual register number in the ptx output.
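  // A hedged example: if %0 and %2 are Int32 vregs and %1 is a Float32 vreg,
  // the per-class numbering prints them as %r1, %r2 and %f1 respectively.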
  unsigned int numVRs = MRI->getNumVirtRegs();
  for (unsigned i = 0; i < numVRs; i++) {
    Register vr = Register::index2VirtReg(i);
    const TargetRegisterClass *RC = MRI->getRegClass(vr);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    int n = regmap.size();
    regmap.insert(std::make_pair(vr, n + 1));
  }

  // Emit register declarations
  // @TODO: Extract out the real register usage
  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";

  // Emit declaration of the virtual registers or 'physical' registers for
  // each register class
  for (unsigned i = 0; i < TRI->getNumRegClasses(); i++) {
    const TargetRegisterClass *RC = TRI->getRegClass(i);
    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
    std::string rcname = getNVPTXRegClassName(RC);
    std::string rcStr = getNVPTXRegClassStr(RC);
    int n = regmap.size();

    // Only declare those registers that may be used.
    if (n) {
      O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n + 1)
        << ">;\n";
    }
  }

  OutStreamer->emitRawText(O.str());
}

/// Translate virtual register numbers in DebugInfo locations to their printed
/// encodings, as used by CUDA-GDB.
void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
    const MachineFunction &MF) {
  const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
  const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();

  // Clear the old mapping, and add the new one. This mapping is used after the
  // printing of the current function is complete, but before the next function
  // is printed.
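  // A hedged note: the map is keyed by vreg and stores the printed name
  // (e.g. "%r2"), which the debug-info emission then uses in place of the
  // LLVM virtual register number.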
1798 registerInfo->clearDebugRegisterMap(); 1799 1800 for (auto &classMap : VRegMapping) { 1801 for (auto ®isterMapping : classMap.getSecond()) { 1802 auto reg = registerMapping.getFirst(); 1803 registerInfo->addToDebugRegisterMap(reg, getVirtualRegisterName(reg)); 1804 } 1805 } 1806 } 1807 1808 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) { 1809 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy 1810 bool ignored; 1811 unsigned int numHex; 1812 const char *lead; 1813 1814 if (Fp->getType()->getTypeID() == Type::FloatTyID) { 1815 numHex = 8; 1816 lead = "0f"; 1817 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored); 1818 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) { 1819 numHex = 16; 1820 lead = "0d"; 1821 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored); 1822 } else 1823 llvm_unreachable("unsupported fp type"); 1824 1825 APInt API = APF.bitcastToAPInt(); 1826 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true); 1827 } 1828 1829 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { 1830 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1831 O << CI->getValue(); 1832 return; 1833 } 1834 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) { 1835 printFPConstant(CFP, O); 1836 return; 1837 } 1838 if (isa<ConstantPointerNull>(CPV)) { 1839 O << "0"; 1840 return; 1841 } 1842 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1843 bool IsNonGenericPointer = false; 1844 if (GVar->getType()->getAddressSpace() != 0) { 1845 IsNonGenericPointer = true; 1846 } 1847 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) { 1848 O << "generic("; 1849 getSymbol(GVar)->print(O, MAI); 1850 O << ")"; 1851 } else { 1852 getSymbol(GVar)->print(O, MAI); 1853 } 1854 return; 1855 } 1856 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1857 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false); 1858 printMCExpr(*E, O); 1859 return; 1860 } 1861 llvm_unreachable("Not scalar type found in printScalarConstant()"); 1862 } 1863 1864 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes, 1865 AggBuffer *AggBuffer) { 1866 const DataLayout &DL = getDataLayout(); 1867 int AllocSize = DL.getTypeAllocSize(CPV->getType()); 1868 if (isa<UndefValue>(CPV) || CPV->isNullValue()) { 1869 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise, 1870 // only the space allocated by CPV. 1871 AggBuffer->addZeros(Bytes ? Bytes : AllocSize); 1872 return; 1873 } 1874 1875 // Helper for filling AggBuffer with APInts. 1876 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) { 1877 size_t NumBytes = (Val.getBitWidth() + 7) / 8; 1878 SmallVector<unsigned char, 16> Buf(NumBytes); 1879 // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the 1880 // input's bit width, and i1 arrays may not have a length that is a multuple 1881 // of 8. We handle the last byte separately, so we never request out of 1882 // bounds bits. 
1883 for (unsigned I = 0; I < NumBytes - 1; ++I) { 1884 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8); 1885 } 1886 size_t LastBytePosition = (NumBytes - 1) * 8; 1887 size_t LastByteBits = Val.getBitWidth() - LastBytePosition; 1888 Buf[NumBytes - 1] = 1889 Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition); 1890 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes); 1891 }; 1892 1893 switch (CPV->getType()->getTypeID()) { 1894 case Type::IntegerTyID: 1895 if (const auto CI = dyn_cast<ConstantInt>(CPV)) { 1896 AddIntToBuffer(CI->getValue()); 1897 break; 1898 } 1899 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1900 if (const auto *CI = 1901 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) { 1902 AddIntToBuffer(CI->getValue()); 1903 break; 1904 } 1905 if (Cexpr->getOpcode() == Instruction::PtrToInt) { 1906 Value *V = Cexpr->getOperand(0)->stripPointerCasts(); 1907 AggBuffer->addSymbol(V, Cexpr->getOperand(0)); 1908 AggBuffer->addZeros(AllocSize); 1909 break; 1910 } 1911 } 1912 llvm_unreachable("unsupported integer const type"); 1913 break; 1914 1915 case Type::HalfTyID: 1916 case Type::BFloatTyID: 1917 case Type::FloatTyID: 1918 case Type::DoubleTyID: 1919 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt()); 1920 break; 1921 1922 case Type::PointerTyID: { 1923 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { 1924 AggBuffer->addSymbol(GVar, GVar); 1925 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { 1926 const Value *v = Cexpr->stripPointerCasts(); 1927 AggBuffer->addSymbol(v, Cexpr); 1928 } 1929 AggBuffer->addZeros(AllocSize); 1930 break; 1931 } 1932 1933 case Type::ArrayTyID: 1934 case Type::FixedVectorTyID: 1935 case Type::StructTyID: { 1936 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) { 1937 bufferAggregateConstant(CPV, AggBuffer); 1938 if (Bytes > AllocSize) 1939 AggBuffer->addZeros(Bytes - AllocSize); 1940 } else if (isa<ConstantAggregateZero>(CPV)) 1941 AggBuffer->addZeros(Bytes); 1942 else 1943 llvm_unreachable("Unexpected Constant type"); 1944 break; 1945 } 1946 1947 default: 1948 llvm_unreachable("unsupported type"); 1949 } 1950 } 1951 1952 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV, 1953 AggBuffer *aggBuffer) { 1954 const DataLayout &DL = getDataLayout(); 1955 int Bytes; 1956 1957 // Integers of arbitrary width 1958 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { 1959 APInt Val = CI->getValue(); 1960 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) { 1961 uint8_t Byte = Val.getLoBits(8).getZExtValue(); 1962 aggBuffer->addBytes(&Byte, 1, 1); 1963 Val.lshrInPlace(8); 1964 } 1965 return; 1966 } 1967 1968 // Old constants 1969 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) { 1970 if (CPV->getNumOperands()) 1971 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) 1972 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer); 1973 return; 1974 } 1975 1976 if (const ConstantDataSequential *CDS = 1977 dyn_cast<ConstantDataSequential>(CPV)) { 1978 if (CDS->getNumElements()) 1979 for (unsigned i = 0; i < CDS->getNumElements(); ++i) 1980 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0, 1981 aggBuffer); 1982 return; 1983 } 1984 1985 if (isa<ConstantStruct>(CPV)) { 1986 if (CPV->getNumOperands()) { 1987 StructType *ST = cast<StructType>(CPV->getType()); 1988 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) { 1989 if (i == (e - 1)) 1990 Bytes = DL.getStructLayout(ST)->getElementOffset(0) + 1991 
DL.getTypeAllocSize(ST) - 1992 DL.getStructLayout(ST)->getElementOffset(i); 1993 else 1994 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) - 1995 DL.getStructLayout(ST)->getElementOffset(i); 1996 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer); 1997 } 1998 } 1999 return; 2000 } 2001 llvm_unreachable("unsupported constant type in printAggregateConstant()"); 2002 } 2003 2004 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly 2005 /// a copy from AsmPrinter::lowerConstant, except customized to only handle 2006 /// expressions that are representable in PTX and create 2007 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions. 2008 const MCExpr * 2009 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) { 2010 MCContext &Ctx = OutContext; 2011 2012 if (CV->isNullValue() || isa<UndefValue>(CV)) 2013 return MCConstantExpr::create(0, Ctx); 2014 2015 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) 2016 return MCConstantExpr::create(CI->getZExtValue(), Ctx); 2017 2018 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) { 2019 const MCSymbolRefExpr *Expr = 2020 MCSymbolRefExpr::create(getSymbol(GV), Ctx); 2021 if (ProcessingGeneric) { 2022 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx); 2023 } else { 2024 return Expr; 2025 } 2026 } 2027 2028 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV); 2029 if (!CE) { 2030 llvm_unreachable("Unknown constant value to lower!"); 2031 } 2032 2033 switch (CE->getOpcode()) { 2034 default: 2035 break; // Error 2036 2037 case Instruction::AddrSpaceCast: { 2038 // Strip the addrspacecast and pass along the operand 2039 PointerType *DstTy = cast<PointerType>(CE->getType()); 2040 if (DstTy->getAddressSpace() == 0) 2041 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true); 2042 2043 break; // Error 2044 } 2045 2046 case Instruction::GetElementPtr: { 2047 const DataLayout &DL = getDataLayout(); 2048 2049 // Generate a symbolic expression for the byte address 2050 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0); 2051 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI); 2052 2053 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0), 2054 ProcessingGeneric); 2055 if (!OffsetAI) 2056 return Base; 2057 2058 int64_t Offset = OffsetAI.getSExtValue(); 2059 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx), 2060 Ctx); 2061 } 2062 2063 case Instruction::Trunc: 2064 // We emit the value and depend on the assembler to truncate the generated 2065 // expression properly. This is important for differences between 2066 // blockaddress labels. Since the two labels are in the same function, it 2067 // is reasonable to treat their delta as a 32-bit value. 2068 [[fallthrough]]; 2069 case Instruction::BitCast: 2070 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2071 2072 case Instruction::IntToPtr: { 2073 const DataLayout &DL = getDataLayout(); 2074 2075 // Handle casts to pointers by changing them into casts to the appropriate 2076 // integer type. This promotes constant folding and simplifies this code. 
2077 Constant *Op = CE->getOperand(0); 2078 Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()), 2079 /*IsSigned*/ false, DL); 2080 if (Op) 2081 return lowerConstantForGV(Op, ProcessingGeneric); 2082 2083 break; // Error 2084 } 2085 2086 case Instruction::PtrToInt: { 2087 const DataLayout &DL = getDataLayout(); 2088 2089 // Support only foldable casts to/from pointers that can be eliminated by 2090 // changing the pointer to the appropriately sized integer type. 2091 Constant *Op = CE->getOperand(0); 2092 Type *Ty = CE->getType(); 2093 2094 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric); 2095 2096 // We can emit the pointer value into this slot if the slot is an 2097 // integer slot equal to the size of the pointer. 2098 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType())) 2099 return OpExpr; 2100 2101 // Otherwise the pointer is smaller than the resultant integer, mask off 2102 // the high bits so we are sure to get a proper truncation if the input is 2103 // a constant expr. 2104 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType()); 2105 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx); 2106 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx); 2107 } 2108 2109 // The MC library also has a right-shift operator, but it isn't consistently 2110 // signed or unsigned between different targets. 2111 case Instruction::Add: { 2112 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric); 2113 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric); 2114 switch (CE->getOpcode()) { 2115 default: llvm_unreachable("Unknown binary operator constant cast expr"); 2116 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx); 2117 } 2118 } 2119 } 2120 2121 // If the code isn't optimized, there may be outstanding folding 2122 // opportunities. Attempt to fold the expression using DataLayout as a 2123 // last resort before giving up. 2124 Constant *C = ConstantFoldConstant(CE, getDataLayout()); 2125 if (C != CE) 2126 return lowerConstantForGV(C, ProcessingGeneric); 2127 2128 // Otherwise report the problem to the user. 2129 std::string S; 2130 raw_string_ostream OS(S); 2131 OS << "Unsupported expression in static initializer: "; 2132 CE->printAsOperand(OS, /*PrintType=*/false, 2133 !MF ? nullptr : MF->getFunction().getParent()); 2134 report_fatal_error(Twine(OS.str())); 2135 } 2136 2137 // Copy of MCExpr::print customized for NVPTX 2138 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) { 2139 switch (Expr.getKind()) { 2140 case MCExpr::Target: 2141 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI); 2142 case MCExpr::Constant: 2143 OS << cast<MCConstantExpr>(Expr).getValue(); 2144 return; 2145 2146 case MCExpr::SymbolRef: { 2147 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr); 2148 const MCSymbol &Sym = SRE.getSymbol(); 2149 Sym.print(OS, MAI); 2150 return; 2151 } 2152 2153 case MCExpr::Unary: { 2154 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr); 2155 switch (UE.getOpcode()) { 2156 case MCUnaryExpr::LNot: OS << '!'; break; 2157 case MCUnaryExpr::Minus: OS << '-'; break; 2158 case MCUnaryExpr::Not: OS << '~'; break; 2159 case MCUnaryExpr::Plus: OS << '+'; break; 2160 } 2161 printMCExpr(*UE.getSubExpr(), OS); 2162 return; 2163 } 2164 2165 case MCExpr::Binary: { 2166 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr); 2167 2168 // Only print parens around the LHS if it is non-trivial. 
2169 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) || 2170 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) { 2171 printMCExpr(*BE.getLHS(), OS); 2172 } else { 2173 OS << '('; 2174 printMCExpr(*BE.getLHS(), OS); 2175 OS<< ')'; 2176 } 2177 2178 switch (BE.getOpcode()) { 2179 case MCBinaryExpr::Add: 2180 // Print "X-42" instead of "X+-42". 2181 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) { 2182 if (RHSC->getValue() < 0) { 2183 OS << RHSC->getValue(); 2184 return; 2185 } 2186 } 2187 2188 OS << '+'; 2189 break; 2190 default: llvm_unreachable("Unhandled binary operator"); 2191 } 2192 2193 // Only print parens around the LHS if it is non-trivial. 2194 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) { 2195 printMCExpr(*BE.getRHS(), OS); 2196 } else { 2197 OS << '('; 2198 printMCExpr(*BE.getRHS(), OS); 2199 OS << ')'; 2200 } 2201 return; 2202 } 2203 } 2204 2205 llvm_unreachable("Invalid expression kind!"); 2206 } 2207 2208 /// PrintAsmOperand - Print out an operand for an inline asm expression. 2209 /// 2210 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 2211 const char *ExtraCode, raw_ostream &O) { 2212 if (ExtraCode && ExtraCode[0]) { 2213 if (ExtraCode[1] != 0) 2214 return true; // Unknown modifier. 2215 2216 switch (ExtraCode[0]) { 2217 default: 2218 // See if this is a generic print operand 2219 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O); 2220 case 'r': 2221 break; 2222 } 2223 } 2224 2225 printOperand(MI, OpNo, O); 2226 2227 return false; 2228 } 2229 2230 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI, 2231 unsigned OpNo, 2232 const char *ExtraCode, 2233 raw_ostream &O) { 2234 if (ExtraCode && ExtraCode[0]) 2235 return true; // Unknown modifier 2236 2237 O << '['; 2238 printMemOperand(MI, OpNo, O); 2239 O << ']'; 2240 2241 return false; 2242 } 2243 2244 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum, 2245 raw_ostream &O) { 2246 const MachineOperand &MO = MI->getOperand(OpNum); 2247 switch (MO.getType()) { 2248 case MachineOperand::MO_Register: 2249 if (MO.getReg().isPhysical()) { 2250 if (MO.getReg() == NVPTX::VRDepot) 2251 O << DEPOTNAME << getFunctionNumber(); 2252 else 2253 O << NVPTXInstPrinter::getRegisterName(MO.getReg()); 2254 } else { 2255 emitVirtualRegister(MO.getReg(), O); 2256 } 2257 break; 2258 2259 case MachineOperand::MO_Immediate: 2260 O << MO.getImm(); 2261 break; 2262 2263 case MachineOperand::MO_FPImmediate: 2264 printFPConstant(MO.getFPImm(), O); 2265 break; 2266 2267 case MachineOperand::MO_GlobalAddress: 2268 PrintSymbolOperand(MO, O); 2269 break; 2270 2271 case MachineOperand::MO_MachineBasicBlock: 2272 MO.getMBB()->getSymbol()->print(O, MAI); 2273 break; 2274 2275 default: 2276 llvm_unreachable("Operand type not supported."); 2277 } 2278 } 2279 2280 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum, 2281 raw_ostream &O, const char *Modifier) { 2282 printOperand(MI, OpNum, O); 2283 2284 if (Modifier && strcmp(Modifier, "add") == 0) { 2285 O << ", "; 2286 printOperand(MI, OpNum + 1, O); 2287 } else { 2288 if (MI->getOperand(OpNum + 1).isImm() && 2289 MI->getOperand(OpNum + 1).getImm() == 0) 2290 return; // don't print ',0' or '+0' 2291 O << "+"; 2292 printOperand(MI, OpNum + 1, O); 2293 } 2294 } 2295 2296 // Force static initialization. 
2297 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() { 2298 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32()); 2299 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64()); 2300 } 2301