xref: /llvm-project/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp (revision 335f1a72b22560e61f6170efef740c9c26b24f1a)
1 //===----- HipStdPar.cpp - HIP C++ Standard Parallelism Support Passes ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file implements two passes that enable HIP C++ Standard Parallelism
9 // Support:
10 //
11 // 1. AcceleratorCodeSelection (required): Given that only algorithms are
12 //    accelerated, and that the accelerated implementation exists in the form of
13 //    a compute kernel, we assume that only the kernel, and all functions
14 //    reachable from it, constitute code that the user expects the accelerator
15 //    to execute. Thus, we identify the set of all functions reachable from
16 //    kernels, and then remove all unreachable ones. This last part is necessary
17 //    because it is possible for code that the user did not expect to execute on
18 //    an accelerator to contain constructs that cannot be handled by the target
19 //    BE, which cannot be provably demonstrated to be dead code in general, and
20 //    thus can lead to mis-compilation. The degenerate case of this is when a
21 //    Module contains no kernels (the parent TU had no algorithm invocations fit
22 //    for acceleration), which we handle by completely emptying said module.
23 //    **NOTE**: The above does not handle indirectly reachable functions i.e.
24 //              it is possible to obtain a case where the target of an indirect
25 //              call is otherwise unreachable and thus is removed; this
26 //              restriction is aligned with the current `-hipstdpar` limitations
27 //              and will be relaxed in the future.
28 //
29 // 2. AllocationInterposition (required only when on-demand paging is
30 //    unsupported): Some accelerators or operating systems might not support
31 //    transparent on-demand paging. Thus, they would only be able to access
32 //    memory that is allocated by an accelerator-aware mechanism. For such cases
33 //    the user can opt into enabling allocation / deallocation interposition,
34 //    whereby we replace calls to known allocation / deallocation functions with
35 //    calls to runtime implemented equivalents that forward the requests to
36 //    accelerator-aware interfaces. We also support freeing system allocated
37 //    memory that ends up in one of the runtime equivalents, since this can
38 //    happen if e.g. a library that was compiled without interposition returns
39 //    an allocation that can be validly passed to `free`.
40 //===----------------------------------------------------------------------===//
41 
42 #include "llvm/Transforms/HipStdPar/HipStdPar.h"
43 
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/SmallPtrSet.h"
46 #include "llvm/ADT/SmallVector.h"
47 #include "llvm/Analysis/CallGraph.h"
48 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
49 #include "llvm/IR/Constants.h"
50 #include "llvm/IR/Function.h"
51 #include "llvm/IR/Module.h"
52 #include "llvm/Transforms/Utils/ModuleUtils.h"
53 
54 #include <cassert>
55 #include <string>
56 #include <utility>
57 
58 using namespace llvm;
59 
60 template<typename T>
61 static inline void eraseFromModule(T &ToErase) {
62   ToErase.replaceAllUsesWith(PoisonValue::get(ToErase.getType()));
63   ToErase.eraseFromParent();
64 }
65 
66 static inline bool checkIfSupported(GlobalVariable &G) {
67   if (!G.isThreadLocal())
68     return true;
69 
70   G.dropDroppableUses();
71 
72   if (!G.isConstantUsed())
73     return true;
74 
75   std::string W;
76   raw_string_ostream OS(W);
77 
78   OS << "Accelerator does not support the thread_local variable "
79     << G.getName();
80 
81   Instruction *I = nullptr;
82   SmallVector<User *> Tmp(G.users());
83   SmallPtrSet<User *, 5> Visited;
84   do {
85     auto U = std::move(Tmp.back());
86     Tmp.pop_back();
87 
88     if (!Visited.insert(U).second)
89       continue;
90 
91     if (isa<Instruction>(U))
92       I = cast<Instruction>(U);
93     else
94       Tmp.insert(Tmp.end(), U->user_begin(), U->user_end());
95   } while (!I && !Tmp.empty());
96 
97   assert(I && "thread_local global should have at least one non-constant use.");
98 
99   G.getContext().diagnose(
100     DiagnosticInfoUnsupported(*I->getParent()->getParent(), W,
101                               I->getDebugLoc(), DS_Error));
102 
103   return false;
104 }
105 
106 static inline void clearModule(Module &M) { // TODO: simplify.
107   while (!M.functions().empty())
108     eraseFromModule(*M.begin());
109   while (!M.globals().empty())
110     eraseFromModule(*M.globals().begin());
111   while (!M.aliases().empty())
112     eraseFromModule(*M.aliases().begin());
113   while (!M.ifuncs().empty())
114     eraseFromModule(*M.ifuncs().begin());
115 }
116 
117 static inline void maybeHandleGlobals(Module &M) {
118   unsigned GlobAS = M.getDataLayout().getDefaultGlobalsAddressSpace();
119   for (auto &&G : M.globals()) { // TODO: should we handle these in the FE?
120     if (!checkIfSupported(G))
121       return clearModule(M);
122 
123     if (G.isThreadLocal())
124       continue;
125     if (G.isConstant())
126       continue;
127     if (G.getAddressSpace() != GlobAS)
128       continue;
129     if (G.getLinkage() != GlobalVariable::ExternalLinkage)
130       continue;
131 
132     G.setLinkage(GlobalVariable::ExternalWeakLinkage);
133     G.setInitializer(nullptr);
134     G.setExternallyInitialized(true);
135   }
136 }
137 
138 template<unsigned N>
139 static inline void removeUnreachableFunctions(
140   const SmallPtrSet<const Function *, N>& Reachable, Module &M) {
141   removeFromUsedLists(M, [&](Constant *C) {
142     if (auto F = dyn_cast<Function>(C))
143       return !Reachable.contains(F);
144 
145     return false;
146   });
147 
148   SmallVector<std::reference_wrapper<Function>> ToRemove;
149   copy_if(M, std::back_inserter(ToRemove), [&](auto &&F) {
150     return !F.isIntrinsic() && !Reachable.contains(&F);
151   });
152 
153   for_each(ToRemove, eraseFromModule<Function>);
154 }
155 
156 static inline bool isAcceleratorExecutionRoot(const Function *F) {
157     if (!F)
158       return false;
159 
160     return F->getCallingConv() == CallingConv::AMDGPU_KERNEL;
161 }
162 
163 static inline bool checkIfSupported(const Function *F, const CallBase *CB) {
164   const auto Dx = F->getName().rfind("__hipstdpar_unsupported");
165 
166   if (Dx == StringRef::npos)
167     return true;
168 
169   const auto N = F->getName().substr(0, Dx);
170 
171   std::string W;
172   raw_string_ostream OS(W);
173 
174   if (N == "__ASM")
175     OS << "Accelerator does not support the ASM block:\n"
176       << cast<ConstantDataArray>(CB->getArgOperand(0))->getAsCString();
177   else
178     OS << "Accelerator does not support the " << N << " function.";
179 
180   auto Caller = CB->getParent()->getParent();
181 
182   Caller->getContext().diagnose(
183     DiagnosticInfoUnsupported(*Caller, W, CB->getDebugLoc(), DS_Error));
184 
185   return false;
186 }
187 
188 PreservedAnalyses
189   HipStdParAcceleratorCodeSelectionPass::run(Module &M,
190                                              ModuleAnalysisManager &MAM) {
191   auto &CGA = MAM.getResult<CallGraphAnalysis>(M);
192 
193   SmallPtrSet<const Function *, 32> Reachable;
194   for (auto &&CGN : CGA) {
195     if (!isAcceleratorExecutionRoot(CGN.first))
196       continue;
197 
198     Reachable.insert(CGN.first);
199 
200     SmallVector<const Function *> Tmp({CGN.first});
201     do {
202       auto F = std::move(Tmp.back());
203       Tmp.pop_back();
204 
205       for (auto &&N : *CGA[F]) {
206         if (!N.second)
207           continue;
208         if (!N.second->getFunction())
209           continue;
210         if (Reachable.contains(N.second->getFunction()))
211           continue;
212 
213         if (!checkIfSupported(N.second->getFunction(),
214                               dyn_cast<CallBase>(*N.first)))
215           return PreservedAnalyses::none();
216 
217         Reachable.insert(N.second->getFunction());
218         Tmp.push_back(N.second->getFunction());
219       }
220     } while (!std::empty(Tmp));
221   }
222 
223   if (std::empty(Reachable))
224     clearModule(M);
225   else
226     removeUnreachableFunctions(Reachable, M);
227 
228   maybeHandleGlobals(M);
229 
230   return PreservedAnalyses::none();
231 }
232 
233 static constexpr std::pair<StringLiteral, StringLiteral> ReplaceMap[]{
234   {"aligned_alloc",             "__hipstdpar_aligned_alloc"},
235   {"calloc",                    "__hipstdpar_calloc"},
236   {"free",                      "__hipstdpar_free"},
237   {"malloc",                    "__hipstdpar_malloc"},
238   {"memalign",                  "__hipstdpar_aligned_alloc"},
239   {"posix_memalign",            "__hipstdpar_posix_aligned_alloc"},
240   {"realloc",                   "__hipstdpar_realloc"},
241   {"reallocarray",              "__hipstdpar_realloc_array"},
242   {"_ZdaPv",                    "__hipstdpar_operator_delete"},
243   {"_ZdaPvm",                   "__hipstdpar_operator_delete_sized"},
244   {"_ZdaPvSt11align_val_t",     "__hipstdpar_operator_delete_aligned"},
245   {"_ZdaPvmSt11align_val_t",    "__hipstdpar_operator_delete_aligned_sized"},
246   {"_ZdlPv",                    "__hipstdpar_operator_delete"},
247   {"_ZdlPvm",                   "__hipstdpar_operator_delete_sized"},
248   {"_ZdlPvSt11align_val_t",     "__hipstdpar_operator_delete_aligned"},
249   {"_ZdlPvmSt11align_val_t",    "__hipstdpar_operator_delete_aligned_sized"},
250   {"_Znam",                     "__hipstdpar_operator_new"},
251   {"_ZnamRKSt9nothrow_t",       "__hipstdpar_operator_new_nothrow"},
252   {"_ZnamSt11align_val_t",      "__hipstdpar_operator_new_aligned"},
253   {"_ZnamSt11align_val_tRKSt9nothrow_t",
254                                 "__hipstdpar_operator_new_aligned_nothrow"},
255 
256   {"_Znwm",                     "__hipstdpar_operator_new"},
257   {"_ZnwmRKSt9nothrow_t",       "__hipstdpar_operator_new_nothrow"},
258   {"_ZnwmSt11align_val_t",      "__hipstdpar_operator_new_aligned"},
259   {"_ZnwmSt11align_val_tRKSt9nothrow_t",
260                                 "__hipstdpar_operator_new_aligned_nothrow"},
261   {"__builtin_calloc",          "__hipstdpar_calloc"},
262   {"__builtin_free",            "__hipstdpar_free"},
263   {"__builtin_malloc",          "__hipstdpar_malloc"},
264   {"__builtin_operator_delete", "__hipstdpar_operator_delete"},
265   {"__builtin_operator_new",    "__hipstdpar_operator_new"},
266   {"__builtin_realloc",         "__hipstdpar_realloc"},
267   {"__libc_calloc",             "__hipstdpar_calloc"},
268   {"__libc_free",               "__hipstdpar_free"},
269   {"__libc_malloc",             "__hipstdpar_malloc"},
270   {"__libc_memalign",           "__hipstdpar_aligned_alloc"},
271   {"__libc_realloc",            "__hipstdpar_realloc"}
272 };
273 
274 PreservedAnalyses
275 HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
276   SmallDenseMap<StringRef, StringRef> AllocReplacements(std::cbegin(ReplaceMap),
277                                                         std::cend(ReplaceMap));
278 
279   for (auto &&F : M) {
280     if (!F.hasName())
281       continue;
282     auto It = AllocReplacements.find(F.getName());
283     if (It == AllocReplacements.end())
284       continue;
285 
286     if (auto R = M.getFunction(It->second)) {
287       F.replaceAllUsesWith(R);
288     } else {
289       std::string W;
290       raw_string_ostream OS(W);
291 
292       OS << "cannot be interposed, missing: " << AllocReplacements[F.getName()]
293         << ". Tried to run the allocation interposition pass without the "
294         << "replacement functions available.";
295 
296       F.getContext().diagnose(DiagnosticInfoUnsupported(F, W,
297                                                         F.getSubprogram(),
298                                                         DS_Warning));
299     }
300   }
301 
302   if (auto F = M.getFunction("__hipstdpar_hidden_malloc")) {
303     auto LibcMalloc = M.getOrInsertFunction(
304         "__libc_malloc", F->getFunctionType(), F->getAttributes());
305     F->replaceAllUsesWith(LibcMalloc.getCallee());
306 
307     eraseFromModule(*F);
308   }
309   if (auto F = M.getFunction("__hipstdpar_hidden_free")) {
310     auto LibcFree = M.getOrInsertFunction("__libc_free", F->getFunctionType(),
311                                           F->getAttributes());
312     F->replaceAllUsesWith(LibcFree.getCallee());
313 
314     eraseFromModule(*F);
315   }
316 
317   return PreservedAnalyses::none();
318 }
319