xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (revision 4583f6d3443c8dc6605c868724e3743161954210)
1 //===- NVPTXUtilities.cpp - Utility Functions -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains miscellaneous utility functions
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "NVPTXUtilities.h"
14 #include "NVPTX.h"
15 #include "NVPTXTargetMachine.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/Function.h"
19 #include "llvm/IR/GlobalVariable.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/Support/Alignment.h"
22 #include "llvm/Support/Mutex.h"
23 #include <cstring>
24 #include <map>
25 #include <mutex>
26 #include <optional>
27 #include <string>
28 #include <vector>
29 
30 namespace llvm {
31 
32 namespace {
33 typedef std::map<std::string, std::vector<unsigned>> key_val_pair_t;
34 typedef std::map<const GlobalValue *, key_val_pair_t> global_val_annot_t;
35 
36 struct AnnotationCache {
37   sys::Mutex Lock;
38   std::map<const Module *, global_val_annot_t> Cache;
39 };
40 
41 AnnotationCache &getAnnotationCache() {
42   static AnnotationCache AC;
43   return AC;
44 }
45 } // anonymous namespace
46 
47 void clearAnnotationCache(const Module *Mod) {
48   auto &AC = getAnnotationCache();
49   std::lock_guard<sys::Mutex> Guard(AC.Lock);
50   AC.Cache.erase(Mod);
51 }
52 
53 static void readIntVecFromMDNode(const MDNode *MetadataNode,
54                                  std::vector<unsigned> &Vec) {
55   for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
56     ConstantInt *Val =
57         mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
58     Vec.push_back(Val->getZExtValue());
59   }
60 }
61 
62 static void cacheAnnotationFromMD(const MDNode *MetadataNode,
63                                   key_val_pair_t &retval) {
64   auto &AC = getAnnotationCache();
65   std::lock_guard<sys::Mutex> Guard(AC.Lock);
66   assert(MetadataNode && "Invalid mdnode for annotation");
67   assert((MetadataNode->getNumOperands() % 2) == 1 &&
68          "Invalid number of operands");
69   // start index = 1, to skip the global variable key
70   // increment = 2, to skip the value for each property-value pairs
71   for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
72     // property
73     const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
74     assert(prop && "Annotation property not a string");
75     std::string Key = prop->getString().str();
76 
77     // value
78     if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
79             MetadataNode->getOperand(i + 1))) {
80       retval[Key].push_back(Val->getZExtValue());
81     } else if (MDNode *VecMd =
82                    dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
83       // note: only "grid_constant" annotations support vector MDNodes.
84       // assert: there can only exist one unique key value pair of
85       // the form (string key, MDNode node). Operands of such a node
86       // shall always be unsigned ints.
87       auto [It, Inserted] = retval.try_emplace(Key);
88       if (Inserted) {
89         readIntVecFromMDNode(VecMd, It->second);
90         continue;
91       }
92     } else {
93       llvm_unreachable("Value operand not a constant int or an mdnode");
94     }
95   }
96 }
97 
98 static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
99   auto &AC = getAnnotationCache();
100   std::lock_guard<sys::Mutex> Guard(AC.Lock);
101   NamedMDNode *NMD = m->getNamedMetadata("nvvm.annotations");
102   if (!NMD)
103     return;
104   key_val_pair_t tmp;
105   for (unsigned i = 0, e = NMD->getNumOperands(); i != e; ++i) {
106     const MDNode *elem = NMD->getOperand(i);
107 
108     GlobalValue *entity =
109         mdconst::dyn_extract_or_null<GlobalValue>(elem->getOperand(0));
110     // entity may be null due to DCE
111     if (!entity)
112       continue;
113     if (entity != gv)
114       continue;
115 
116     // accumulate annotations for entity in tmp
117     cacheAnnotationFromMD(elem, tmp);
118   }
119 
120   if (tmp.empty()) // no annotations for this gv
121     return;
122 
123   AC.Cache[m][gv] = std::move(tmp);
124 }
125 
126 static std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
127                                                      const std::string &prop) {
128   auto &AC = getAnnotationCache();
129   std::lock_guard<sys::Mutex> Guard(AC.Lock);
130   const Module *m = gv->getParent();
131   if (AC.Cache.find(m) == AC.Cache.end())
132     cacheAnnotationFromMD(m, gv);
133   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
134     cacheAnnotationFromMD(m, gv);
135   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
136     return std::nullopt;
137   return AC.Cache[m][gv][prop][0];
138 }
139 
140 static bool findAllNVVMAnnotation(const GlobalValue *gv,
141                                   const std::string &prop,
142                                   std::vector<unsigned> &retval) {
143   auto &AC = getAnnotationCache();
144   std::lock_guard<sys::Mutex> Guard(AC.Lock);
145   const Module *m = gv->getParent();
146   if (AC.Cache.find(m) == AC.Cache.end())
147     cacheAnnotationFromMD(m, gv);
148   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
149     cacheAnnotationFromMD(m, gv);
150   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
151     return false;
152   retval = AC.Cache[m][gv][prop];
153   return true;
154 }
155 
156 static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
157   if (const auto *GV = dyn_cast<GlobalValue>(&V))
158     if (const auto Annot = findOneNVVMAnnotation(GV, Prop)) {
159       assert((*Annot == 1) && "Unexpected annotation on a symbol");
160       return true;
161     }
162 
163   return false;
164 }
165 
166 static bool argHasNVVMAnnotation(const Value &Val,
167                                  const std::string &Annotation,
168                                  const bool StartArgIndexAtOne = false) {
169   if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
170     const Function *Func = Arg->getParent();
171     std::vector<unsigned> Annot;
172     if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
173       const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
174       if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
175         return true;
176       }
177     }
178   }
179   return false;
180 }
181 
182 bool isParamGridConstant(const Value &V) {
183   if (const Argument *Arg = dyn_cast<Argument>(&V)) {
184     // "grid_constant" counts argument indices starting from 1
185     if (Arg->hasByValAttr() &&
186         argHasNVVMAnnotation(*Arg, "grid_constant",
187                              /*StartArgIndexAtOne*/ true)) {
188       assert(isKernelFunction(*Arg->getParent()) &&
189              "only kernel arguments can be grid_constant");
190       return true;
191     }
192   }
193   return false;
194 }
195 
196 bool isTexture(const Value &V) { return globalHasNVVMAnnotation(V, "texture"); }
197 
198 bool isSurface(const Value &V) { return globalHasNVVMAnnotation(V, "surface"); }
199 
200 bool isSampler(const Value &V) {
201   const char *AnnotationName = "sampler";
202 
203   return globalHasNVVMAnnotation(V, AnnotationName) ||
204          argHasNVVMAnnotation(V, AnnotationName);
205 }
206 
207 bool isImageReadOnly(const Value &V) {
208   return argHasNVVMAnnotation(V, "rdoimage");
209 }
210 
211 bool isImageWriteOnly(const Value &V) {
212   return argHasNVVMAnnotation(V, "wroimage");
213 }
214 
215 bool isImageReadWrite(const Value &V) {
216   return argHasNVVMAnnotation(V, "rdwrimage");
217 }
218 
219 bool isImage(const Value &V) {
220   return isImageReadOnly(V) || isImageWriteOnly(V) || isImageReadWrite(V);
221 }
222 
223 bool isManaged(const Value &V) { return globalHasNVVMAnnotation(V, "managed"); }
224 
225 StringRef getTextureName(const Value &V) {
226   assert(V.hasName() && "Found texture variable with no name");
227   return V.getName();
228 }
229 
230 StringRef getSurfaceName(const Value &V) {
231   assert(V.hasName() && "Found surface variable with no name");
232   return V.getName();
233 }
234 
235 StringRef getSamplerName(const Value &V) {
236   assert(V.hasName() && "Found sampler variable with no name");
237   return V.getName();
238 }
239 
240 std::optional<unsigned> getMaxNTIDx(const Function &F) {
241   return findOneNVVMAnnotation(&F, "maxntidx");
242 }
243 
244 std::optional<unsigned> getMaxNTIDy(const Function &F) {
245   return findOneNVVMAnnotation(&F, "maxntidy");
246 }
247 
248 std::optional<unsigned> getMaxNTIDz(const Function &F) {
249   return findOneNVVMAnnotation(&F, "maxntidz");
250 }
251 
252 std::optional<unsigned> getMaxNTID(const Function &F) {
253   // Note: The semantics here are a bit strange. The PTX ISA states the
254   // following (11.4.2. Performance-Tuning Directives: .maxntid):
255   //
256   //  Note that this directive guarantees that the total number of threads does
257   //  not exceed the maximum, but does not guarantee that the limit in any
258   //  particular dimension is not exceeded.
259   std::optional<unsigned> MaxNTIDx = getMaxNTIDx(F);
260   std::optional<unsigned> MaxNTIDy = getMaxNTIDy(F);
261   std::optional<unsigned> MaxNTIDz = getMaxNTIDz(F);
262   if (MaxNTIDx || MaxNTIDy || MaxNTIDz)
263     return MaxNTIDx.value_or(1) * MaxNTIDy.value_or(1) * MaxNTIDz.value_or(1);
264   return std::nullopt;
265 }
266 
267 std::optional<unsigned> getClusterDimx(const Function &F) {
268   return findOneNVVMAnnotation(&F, "cluster_dim_x");
269 }
270 
271 std::optional<unsigned> getClusterDimy(const Function &F) {
272   return findOneNVVMAnnotation(&F, "cluster_dim_y");
273 }
274 
275 std::optional<unsigned> getClusterDimz(const Function &F) {
276   return findOneNVVMAnnotation(&F, "cluster_dim_z");
277 }
278 
279 std::optional<unsigned> getMaxClusterRank(const Function &F) {
280   return findOneNVVMAnnotation(&F, "maxclusterrank");
281 }
282 
283 std::optional<unsigned> getReqNTIDx(const Function &F) {
284   return findOneNVVMAnnotation(&F, "reqntidx");
285 }
286 
287 std::optional<unsigned> getReqNTIDy(const Function &F) {
288   return findOneNVVMAnnotation(&F, "reqntidy");
289 }
290 
291 std::optional<unsigned> getReqNTIDz(const Function &F) {
292   return findOneNVVMAnnotation(&F, "reqntidz");
293 }
294 
295 std::optional<unsigned> getReqNTID(const Function &F) {
296   // Note: The semantics here are a bit strange. See getMaxNTID.
297   std::optional<unsigned> ReqNTIDx = getReqNTIDx(F);
298   std::optional<unsigned> ReqNTIDy = getReqNTIDy(F);
299   std::optional<unsigned> ReqNTIDz = getReqNTIDz(F);
300   if (ReqNTIDx || ReqNTIDy || ReqNTIDz)
301     return ReqNTIDx.value_or(1) * ReqNTIDy.value_or(1) * ReqNTIDz.value_or(1);
302   return std::nullopt;
303 }
304 
305 std::optional<unsigned> getMinCTASm(const Function &F) {
306   return findOneNVVMAnnotation(&F, "minctasm");
307 }
308 
309 std::optional<unsigned> getMaxNReg(const Function &F) {
310   return findOneNVVMAnnotation(&F, "maxnreg");
311 }
312 
313 bool isKernelFunction(const Function &F) {
314   if (F.getCallingConv() == CallingConv::PTX_Kernel)
315     return true;
316 
317   if (const auto X = findOneNVVMAnnotation(&F, "kernel"))
318     return (*X == 1);
319 
320   return false;
321 }
322 
323 MaybeAlign getAlign(const Function &F, unsigned Index) {
324   // First check the alignstack metadata
325   if (MaybeAlign StackAlign =
326           F.getAttributes().getAttributes(Index).getStackAlignment())
327     return StackAlign;
328 
329   // If that is missing, check the legacy nvvm metadata
330   std::vector<unsigned> Vs;
331   bool retval = findAllNVVMAnnotation(&F, "align", Vs);
332   if (!retval)
333     return std::nullopt;
334   for (unsigned V : Vs)
335     if ((V >> 16) == Index)
336       return Align(V & 0xFFFF);
337 
338   return std::nullopt;
339 }
340 
341 MaybeAlign getAlign(const CallInst &I, unsigned Index) {
342   // First check the alignstack metadata
343   if (MaybeAlign StackAlign =
344           I.getAttributes().getAttributes(Index).getStackAlignment())
345     return StackAlign;
346 
347   // If that is missing, check the legacy nvvm metadata
348   if (MDNode *alignNode = I.getMetadata("callalign")) {
349     for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
350       if (const ConstantInt *CI =
351               mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
352         unsigned V = CI->getZExtValue();
353         if ((V >> 16) == Index)
354           return Align(V & 0xFFFF);
355         if ((V >> 16) > Index)
356           return std::nullopt;
357       }
358     }
359   }
360   return std::nullopt;
361 }
362 
363 Function *getMaybeBitcastedCallee(const CallBase *CB) {
364   return dyn_cast<Function>(CB->getCalledOperand()->stripPointerCasts());
365 }
366 
367 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
368   const auto &ST =
369       *static_cast<const NVPTXTargetMachine &>(TM).getSubtargetImpl();
370   if (!ST.hasNoReturn())
371     return false;
372 
373   assert((isa<Function>(V) || isa<CallInst>(V)) &&
374          "Expect either a call instruction or a function");
375 
376   if (const CallInst *CallI = dyn_cast<CallInst>(V))
377     return CallI->doesNotReturn() &&
378            CallI->getFunctionType()->getReturnType()->isVoidTy();
379 
380   const Function *F = cast<Function>(V);
381   return F->doesNotReturn() &&
382          F->getFunctionType()->getReturnType()->isVoidTy() &&
383          !isKernelFunction(*F);
384 }
385 
386 bool Isv2x16VT(EVT VT) {
387   return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
388 }
389 
390 } // namespace llvm
391