xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp (revision 4583f6d3443c8dc6605c868724e3743161954210)
1 //===-- NVPTXCtorDtorLowering.cpp - Handle global ctors and dtors --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This pass creates a unified init and fini kernel with the required metadata
11 //===----------------------------------------------------------------------===//
12 
13 #include "NVPTXCtorDtorLowering.h"
14 #include "MCTargetDesc/NVPTXBaseInfo.h"
15 #include "NVPTX.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/IR/CallingConv.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Function.h"
20 #include "llvm/IR/GlobalVariable.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/MD5.h"
27 #include "llvm/Transforms/Utils/ModuleUtils.h"
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "nvptx-lower-ctor-dtor"
32 
// Hidden option: when non-empty, this string is used in place of the hash of
// the module's source file name when building the names of the per-callback
// ctor/dtor globals.
static cl::opt<std::string>
    GlobalStr("nvptx-lower-global-ctor-dtor-id",
              cl::desc("Override unique ID of ctor/dtor globals."),
              cl::init(""), cl::Hidden);

// Hidden option (default: true): when false, only the per-callback globals are
// emitted and no init/fini kernel is created.
static cl::opt<bool>
    CreateKernels("nvptx-emit-init-fini-kernel",
                  cl::desc("Emit kernels to call ctor/dtor globals."),
                  cl::init(true), cl::Hidden);
42 
43 namespace {
44 
45 static std::string getHash(StringRef Str) {
46   llvm::MD5 Hasher;
47   llvm::MD5::MD5Result Hash;
48   Hasher.update(Str);
49   Hasher.final(Hash);
50   return llvm::utohexstr(Hash.low(), /*LowerCase=*/true);
51 }
52 
53 static void addKernelMetadata(Module &M, Function *F) {
54   llvm::LLVMContext &Ctx = M.getContext();
55 
56   // Get "nvvm.annotations" metadata node.
57   llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
58 
59   // This kernel is only to be called single-threaded.
60   llvm::Metadata *ThreadXMDVals[] = {
61       llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidx"),
62       llvm::ConstantAsMetadata::get(
63           llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
64   llvm::Metadata *ThreadYMDVals[] = {
65       llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidy"),
66       llvm::ConstantAsMetadata::get(
67           llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
68   llvm::Metadata *ThreadZMDVals[] = {
69       llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "maxntidz"),
70       llvm::ConstantAsMetadata::get(
71           llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
72 
73   llvm::Metadata *BlockMDVals[] = {
74       llvm::ConstantAsMetadata::get(F),
75       llvm::MDString::get(Ctx, "maxclusterrank"),
76       llvm::ConstantAsMetadata::get(
77           llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
78 
79   // Append metadata to nvvm.annotations.
80   F->setCallingConv(CallingConv::PTX_Kernel);
81   MD->addOperand(llvm::MDNode::get(Ctx, ThreadXMDVals));
82   MD->addOperand(llvm::MDNode::get(Ctx, ThreadYMDVals));
83   MD->addOperand(llvm::MDNode::get(Ctx, ThreadZMDVals));
84   MD->addOperand(llvm::MDNode::get(Ctx, BlockMDVals));
85 }
86 
87 static Function *createInitOrFiniKernelFunction(Module &M, bool IsCtor) {
88   StringRef InitOrFiniKernelName =
89       IsCtor ? "nvptx$device$init" : "nvptx$device$fini";
90   if (M.getFunction(InitOrFiniKernelName))
91     return nullptr;
92 
93   Function *InitOrFiniKernel = Function::createWithDefaultAttr(
94       FunctionType::get(Type::getVoidTy(M.getContext()), false),
95       GlobalValue::WeakODRLinkage, 0, InitOrFiniKernelName, &M);
96   addKernelMetadata(M, InitOrFiniKernel);
97 
98   return InitOrFiniKernel;
99 }
100 
101 // We create the IR required to call each callback in this section. This is
102 // equivalent to the following code. Normally, the linker would provide us with
103 // the definitions of the init and fini array sections. The 'nvlink' linker does
104 // not do this so initializing these values is done by the runtime.
105 //
106 // extern "C" void **__init_array_start = nullptr;
107 // extern "C" void **__init_array_end = nullptr;
108 // extern "C" void **__fini_array_start = nullptr;
109 // extern "C" void **__fini_array_end = nullptr;
110 //
111 // using InitCallback = void();
112 // using FiniCallback = void();
113 //
114 // void call_init_array_callbacks() {
115 //   for (auto start = __init_array_start; start != __init_array_end; ++start)
116 //     reinterpret_cast<InitCallback *>(*start)();
117 // }
118 //
119 // void call_init_array_callbacks() {
120 //   size_t fini_array_size = __fini_array_end - __fini_array_start;
121 //   for (size_t i = fini_array_size; i > 0; --i)
122 //     reinterpret_cast<FiniCallback *>(__fini_array_start[i - 1])();
123 // }
124 static void createInitOrFiniCalls(Function &F, bool IsCtor) {
125   Module &M = *F.getParent();
126   LLVMContext &C = M.getContext();
127 
128   IRBuilder<> IRB(BasicBlock::Create(C, "entry", &F));
129   auto *LoopBB = BasicBlock::Create(C, "while.entry", &F);
130   auto *ExitBB = BasicBlock::Create(C, "while.end", &F);
131   Type *PtrTy = IRB.getPtrTy(llvm::ADDRESS_SPACE_GLOBAL);
132 
133   auto *Begin = M.getOrInsertGlobal(
134       IsCtor ? "__init_array_start" : "__fini_array_start",
135       PointerType::get(C, 0), [&]() {
136         auto *GV = new GlobalVariable(
137             M, PointerType::get(C, 0),
138             /*isConstant=*/false, GlobalValue::WeakAnyLinkage,
139             Constant::getNullValue(PointerType::get(C, 0)),
140             IsCtor ? "__init_array_start" : "__fini_array_start",
141             /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
142             /*AddressSpace=*/llvm::ADDRESS_SPACE_GLOBAL);
143         GV->setVisibility(GlobalVariable::ProtectedVisibility);
144         return GV;
145       });
146   auto *End = M.getOrInsertGlobal(
147       IsCtor ? "__init_array_end" : "__fini_array_end", PointerType::get(C, 0),
148       [&]() {
149         auto *GV = new GlobalVariable(
150             M, PointerType::get(C, 0),
151             /*isConstant=*/false, GlobalValue::WeakAnyLinkage,
152             Constant::getNullValue(PointerType::get(C, 0)),
153             IsCtor ? "__init_array_end" : "__fini_array_end",
154             /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
155             /*AddressSpace=*/llvm::ADDRESS_SPACE_GLOBAL);
156         GV->setVisibility(GlobalVariable::ProtectedVisibility);
157         return GV;
158       });
159 
160   // The constructor type is suppoed to allow using the argument vectors, but
161   // for now we just call them with no arguments.
162   auto *CallBackTy = FunctionType::get(IRB.getVoidTy(), {});
163 
164   // The destructor array must be called in reverse order. Get an expression to
165   // the end of the array and iterate backwards in that case.
166   Value *BeginVal = IRB.CreateLoad(Begin->getType(), Begin, "begin");
167   Value *EndVal = IRB.CreateLoad(Begin->getType(), End, "stop");
168   if (!IsCtor) {
169     auto *BeginInt = IRB.CreatePtrToInt(BeginVal, IntegerType::getInt64Ty(C));
170     auto *EndInt = IRB.CreatePtrToInt(EndVal, IntegerType::getInt64Ty(C));
171     auto *SubInst = IRB.CreateSub(EndInt, BeginInt);
172     auto *Offset = IRB.CreateAShr(
173         SubInst, ConstantInt::get(IntegerType::getInt64Ty(C), 3), "offset",
174         /*IsExact=*/true);
175     auto *ValuePtr = IRB.CreateGEP(PointerType::get(C, 0), BeginVal,
176                                    ArrayRef<Value *>({Offset}));
177     EndVal = BeginVal;
178     BeginVal = IRB.CreateInBoundsGEP(
179         PointerType::get(C, 0), ValuePtr,
180         ArrayRef<Value *>(ConstantInt::get(IntegerType::getInt64Ty(C), -1)),
181         "start");
182   }
183   IRB.CreateCondBr(
184       IRB.CreateCmp(IsCtor ? ICmpInst::ICMP_NE : ICmpInst::ICMP_UGT, BeginVal,
185                     EndVal),
186       LoopBB, ExitBB);
187   IRB.SetInsertPoint(LoopBB);
188   auto *CallBackPHI = IRB.CreatePHI(PtrTy, 2, "ptr");
189   auto *CallBack = IRB.CreateLoad(IRB.getPtrTy(F.getAddressSpace()),
190                                   CallBackPHI, "callback");
191   IRB.CreateCall(CallBackTy, CallBack);
192   auto *NewCallBack =
193       IRB.CreateConstGEP1_64(PtrTy, CallBackPHI, IsCtor ? 1 : -1, "next");
194   auto *EndCmp = IRB.CreateCmp(IsCtor ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_ULT,
195                                NewCallBack, EndVal, "end");
196   CallBackPHI->addIncoming(BeginVal, &F.getEntryBlock());
197   CallBackPHI->addIncoming(NewCallBack, LoopBB);
198   IRB.CreateCondBr(EndCmp, ExitBB, LoopBB);
199   IRB.SetInsertPoint(ExitBB);
200   IRB.CreateRetVoid();
201 }
202 
203 static bool createInitOrFiniGlobals(Module &M, GlobalVariable *GV,
204                                     bool IsCtor) {
205   ConstantArray *GA = dyn_cast<ConstantArray>(GV->getInitializer());
206   if (!GA || GA->getNumOperands() == 0)
207     return false;
208 
209   // NVPTX has no way to emit variables at specific sections or support for
210   // the traditional constructor sections. Instead, we emit mangled global
211   // names so the runtime can build the list manually.
212   for (Value *V : GA->operands()) {
213     auto *CS = cast<ConstantStruct>(V);
214     auto *F = cast<Constant>(CS->getOperand(1));
215     uint64_t Priority = cast<ConstantInt>(CS->getOperand(0))->getSExtValue();
216     std::string PriorityStr = "." + std::to_string(Priority);
217     // We append a semi-unique hash and the priority to the global name.
218     std::string GlobalID =
219         !GlobalStr.empty() ? GlobalStr : getHash(M.getSourceFileName());
220     std::string NameStr =
221         ((IsCtor ? "__init_array_object_" : "__fini_array_object_") +
222          F->getName() + "_" + GlobalID + "_" + std::to_string(Priority))
223             .str();
224     // PTX does not support exported names with '.' in them.
225     llvm::transform(NameStr, NameStr.begin(),
226                     [](char c) { return c == '.' ? '_' : c; });
227 
228     auto *GV = new GlobalVariable(M, F->getType(), /*IsConstant=*/true,
229                                   GlobalValue::ExternalLinkage, F, NameStr,
230                                   nullptr, GlobalValue::NotThreadLocal,
231                                   /*AddressSpace=*/4);
232     // This isn't respected by Nvidia, simply put here for clarity.
233     GV->setSection(IsCtor ? ".init_array" + PriorityStr
234                           : ".fini_array" + PriorityStr);
235     GV->setVisibility(GlobalVariable::ProtectedVisibility);
236     appendToUsed(M, {GV});
237   }
238 
239   return true;
240 }
241 
242 static bool createInitOrFiniKernel(Module &M, StringRef GlobalName,
243                                    bool IsCtor) {
244   GlobalVariable *GV = M.getGlobalVariable(GlobalName);
245   if (!GV || !GV->hasInitializer())
246     return false;
247 
248   if (!createInitOrFiniGlobals(M, GV, IsCtor))
249     return false;
250 
251   if (!CreateKernels)
252     return true;
253 
254   Function *InitOrFiniKernel = createInitOrFiniKernelFunction(M, IsCtor);
255   if (!InitOrFiniKernel)
256     return false;
257 
258   createInitOrFiniCalls(*InitOrFiniKernel, IsCtor);
259 
260   GV->eraseFromParent();
261   return true;
262 }
263 
264 static bool lowerCtorsAndDtors(Module &M) {
265   bool Modified = false;
266   Modified |= createInitOrFiniKernel(M, "llvm.global_ctors", /*IsCtor =*/true);
267   Modified |= createInitOrFiniKernel(M, "llvm.global_dtors", /*IsCtor =*/false);
268   return Modified;
269 }
270 
271 class NVPTXCtorDtorLoweringLegacy final : public ModulePass {
272 public:
273   static char ID;
274   NVPTXCtorDtorLoweringLegacy() : ModulePass(ID) {}
275   bool runOnModule(Module &M) override { return lowerCtorsAndDtors(M); }
276 };
277 
278 } // End anonymous namespace
279 
280 PreservedAnalyses NVPTXCtorDtorLoweringPass::run(Module &M,
281                                                  ModuleAnalysisManager &AM) {
282   return lowerCtorsAndDtors(M) ? PreservedAnalyses::none()
283                                : PreservedAnalyses::all();
284 }
285 
// Definition of the legacy pass's static ID member, and the reference to it
// that is exposed in the llvm namespace.
char NVPTXCtorDtorLoweringLegacy::ID = 0;
char &llvm::NVPTXCtorDtorLoweringLegacyPassID = NVPTXCtorDtorLoweringLegacy::ID;
// Registers the legacy pass under the DEBUG_TYPE name ("nvptx-lower-ctor-dtor").
INITIALIZE_PASS(NVPTXCtorDtorLoweringLegacy, DEBUG_TYPE,
                "Lower ctors and dtors for NVPTX", false, false)
290 
291 ModulePass *llvm::createNVPTXCtorDtorLoweringLegacyPass() {
292   return new NVPTXCtorDtorLoweringLegacy();
293 }
294