xref: /llvm-project/offload/include/OpenMP/OMPT/Interface.h (revision d36f66b42d7abec73bb5b953612eef26e6c12e0a)
1 //===-- OpenMP/OMPT/Interface.h - OpenMP Tooling interfaces ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Declarations for OpenMP Tool callback dispatchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef OFFLOAD_INCLUDE_OPENMP_OMPT_INTERFACE_H
14 #define OFFLOAD_INCLUDE_OPENMP_OMPT_INTERFACE_H
15 
16 // Only provide functionality if target OMPT support is enabled
17 #ifdef OMPT_SUPPORT
#include "Callback.h"
#include "omp-tools.h"

#include "llvm/Support/ErrorHandling.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <tuple>
#include <utility>
25 
// OMPT support is compiled in: OMPT_IF_BUILT(Stmt) expands to Stmt verbatim.
#define OMPT_IF_BUILT(Stmt) Stmt
27 
28 /// Callbacks for target regions require task_data representing the
29 /// encountering task.
30 /// Callbacks for target regions and target data ops require
31 /// target_task_data representing the target task region.
32 typedef ompt_data_t *(*ompt_get_task_data_t)();
33 typedef ompt_data_t *(*ompt_get_target_task_data_t)();
34 
35 namespace llvm {
36 namespace omp {
37 namespace target {
38 namespace ompt {
39 
40 /// Function pointers that will be used to track task_data and
41 /// target_task_data.
42 static ompt_get_task_data_t ompt_get_task_data_fn;
43 static ompt_get_target_task_data_t ompt_get_target_task_data_fn;
44 
45 /// Used to maintain execution state for this thread
46 class Interface {
47 public:
48   /// Top-level function for invoking callback before device data allocation
49   void beginTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
50                             void **TgtPtrBegin, size_t Size, void *Code);
51 
52   /// Top-level function for invoking callback after device data allocation
53   void endTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
54                           void **TgtPtrBegin, size_t Size, void *Code);
55 
56   /// Top-level function for invoking callback before data submit
57   void beginTargetDataSubmit(int64_t SrcDeviceId, void *SrcPtrBegin,
58                              int64_t DstDeviceId, void *DstPtrBegin,
59                              size_t Size, void *Code);
60 
61   /// Top-level function for invoking callback after data submit
62   void endTargetDataSubmit(int64_t SrcDeviceId, void *SrcPtrBegin,
63                            int64_t DstDeviceId, void *DstPtrBegin, size_t Size,
64                            void *Code);
65 
66   /// Top-level function for invoking callback before device data deallocation
67   void beginTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
68 
69   /// Top-level function for invoking callback after device data deallocation
70   void endTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
71 
72   /// Top-level function for invoking callback before data retrieve
73   void beginTargetDataRetrieve(int64_t SrcDeviceId, void *SrcPtrBegin,
74                                int64_t DstDeviceId, void *DstPtrBegin,
75                                size_t Size, void *Code);
76 
77   /// Top-level function for invoking callback after data retrieve
78   void endTargetDataRetrieve(int64_t SrcDeviceId, void *SrcPtrBegin,
79                              int64_t DstDeviceId, void *DstPtrBegin,
80                              size_t Size, void *Code);
81 
82   /// Top-level function for invoking callback before kernel dispatch
83   void beginTargetSubmit(unsigned int NumTeams = 1);
84 
85   /// Top-level function for invoking callback after kernel dispatch
86   void endTargetSubmit(unsigned int NumTeams = 1);
87 
88   // Target region callbacks
89 
90   /// Top-level function for invoking callback before target enter data
91   /// construct
92   void beginTargetDataEnter(int64_t DeviceId, void *Code);
93 
94   /// Top-level function for invoking callback after target enter data
95   /// construct
96   void endTargetDataEnter(int64_t DeviceId, void *Code);
97 
98   /// Top-level function for invoking callback before target exit data
99   /// construct
100   void beginTargetDataExit(int64_t DeviceId, void *Code);
101 
102   /// Top-level function for invoking callback after target exit data
103   /// construct
104   void endTargetDataExit(int64_t DeviceId, void *Code);
105 
106   /// Top-level function for invoking callback before target update construct
107   void beginTargetUpdate(int64_t DeviceId, void *Code);
108 
109   /// Top-level function for invoking callback after target update construct
110   void endTargetUpdate(int64_t DeviceId, void *Code);
111 
112   /// Top-level function for invoking callback before target associate API
113   void beginTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
114                                    void *TgtPtrBegin, size_t Size, void *Code);
115 
116   /// Top-level function for invoking callback after target associate API
117   void endTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
118                                  void *TgtPtrBegin, size_t Size, void *Code);
119 
120   /// Top-level function for invoking callback before target disassociate API
121   void beginTargetDisassociatePointer(int64_t DeviceId, void *HstPtrBegin,
122                                       void *TgtPtrBegin, size_t Size,
123                                       void *Code);
124 
125   /// Top-level function for invoking callback after target disassociate API
126   void endTargetDisassociatePointer(int64_t DeviceId, void *HstPtrBegin,
127                                     void *TgtPtrBegin, size_t Size, void *Code);
128 
129   // Target kernel callbacks
130 
131   /// Top-level function for invoking callback before target construct
132   void beginTarget(int64_t DeviceId, void *Code);
133 
134   /// Top-level function for invoking callback after target construct
135   void endTarget(int64_t DeviceId, void *Code);
136 
137   // Callback getter: Target data operations
138   template <ompt_target_data_op_t OpType> auto getCallbacks() {
139     if constexpr (OpType == ompt_target_data_alloc ||
140                   OpType == ompt_target_data_alloc_async)
141       return std::make_pair(std::mem_fn(&Interface::beginTargetDataAlloc),
142                             std::mem_fn(&Interface::endTargetDataAlloc));
143 
144     if constexpr (OpType == ompt_target_data_delete ||
145                   OpType == ompt_target_data_delete_async)
146       return std::make_pair(std::mem_fn(&Interface::beginTargetDataDelete),
147                             std::mem_fn(&Interface::endTargetDataDelete));
148 
149     if constexpr (OpType == ompt_target_data_transfer_to_device ||
150                   OpType == ompt_target_data_transfer_to_device_async)
151       return std::make_pair(std::mem_fn(&Interface::beginTargetDataSubmit),
152                             std::mem_fn(&Interface::endTargetDataSubmit));
153 
154     if constexpr (OpType == ompt_target_data_transfer_from_device ||
155                   OpType == ompt_target_data_transfer_from_device_async)
156       return std::make_pair(std::mem_fn(&Interface::beginTargetDataRetrieve),
157                             std::mem_fn(&Interface::endTargetDataRetrieve));
158 
159     if constexpr (OpType == ompt_target_data_associate)
160       return std::make_pair(
161           std::mem_fn(&Interface::beginTargetAssociatePointer),
162           std::mem_fn(&Interface::endTargetAssociatePointer));
163 
164     if constexpr (OpType == ompt_target_data_disassociate)
165       return std::make_pair(
166           std::mem_fn(&Interface::beginTargetDisassociatePointer),
167           std::mem_fn(&Interface::endTargetDisassociatePointer));
168 
169     llvm_unreachable("Unhandled target data operation type!");
170   }
171 
172   // Callback getter: Target region operations
173   template <ompt_target_t OpType> auto getCallbacks() {
174     if constexpr (OpType == ompt_target_enter_data ||
175                   OpType == ompt_target_enter_data_nowait)
176       return std::make_pair(std::mem_fn(&Interface::beginTargetDataEnter),
177                             std::mem_fn(&Interface::endTargetDataEnter));
178 
179     if constexpr (OpType == ompt_target_exit_data ||
180                   OpType == ompt_target_exit_data_nowait)
181       return std::make_pair(std::mem_fn(&Interface::beginTargetDataExit),
182                             std::mem_fn(&Interface::endTargetDataExit));
183 
184     if constexpr (OpType == ompt_target_update ||
185                   OpType == ompt_target_update_nowait)
186       return std::make_pair(std::mem_fn(&Interface::beginTargetUpdate),
187                             std::mem_fn(&Interface::endTargetUpdate));
188 
189     if constexpr (OpType == ompt_target || OpType == ompt_target_nowait)
190       return std::make_pair(std::mem_fn(&Interface::beginTarget),
191                             std::mem_fn(&Interface::endTarget));
192 
193     llvm_unreachable("Unknown target region operation type!");
194   }
195 
196   // Callback getter: Kernel launch operation
197   template <ompt_callbacks_t OpType> auto getCallbacks() {
198     // We use 'ompt_callbacks_t', because no other enum is currently available
199     // to model a kernel launch / target submit operation.
200     if constexpr (OpType == ompt_callback_target_submit)
201       return std::make_pair(std::mem_fn(&Interface::beginTargetSubmit),
202                             std::mem_fn(&Interface::endTargetSubmit));
203 
204     llvm_unreachable("Unhandled target operation!");
205   }
206 
207   /// Setters for target region and target operation correlation ids
208   void setTargetDataValue(uint64_t DataValue) { TargetData.value = DataValue; }
209   void setTargetDataPtr(void *DataPtr) { TargetData.ptr = DataPtr; }
210   void setHostOpId(ompt_id_t OpId) { HostOpId = OpId; }
211 
212   /// Getters for target region and target operation correlation ids
213   uint64_t getTargetDataValue() { return TargetData.value; }
214   void *getTargetDataPtr() { return TargetData.ptr; }
215   ompt_id_t getHostOpId() { return HostOpId; }
216 
217 private:
218   /// Target operations id
219   ompt_id_t HostOpId = 0;
220 
221   /// Target region data
222   ompt_data_t TargetData = ompt_data_none;
223 
224   /// Task data representing the encountering task
225   ompt_data_t *TaskData = nullptr;
226 
227   /// Target task data representing the target task region
228   ompt_data_t *TargetTaskData = nullptr;
229 
230   /// Used for marking begin of a data operation
231   void beginTargetDataOperation();
232 
233   /// Used for marking end of a data operation
234   void endTargetDataOperation();
235 
236   /// Used for marking begin of a target region
237   void beginTargetRegion();
238 
239   /// Used for marking end of a target region
240   void endTargetRegion();
241 };
242 
243 /// Thread local state for target region and associated metadata
244 extern thread_local Interface RegionInterface;
245 
246 /// Thread local variable holding the return address.
247 /// When using __builtin_return_address to set the return address,
248 /// allow 0 as the only argument to avoid unpredictable effects.
249 extern thread_local void *ReturnAddress;
250 
251 template <typename FuncTy, typename ArgsTy, size_t... IndexSeq>
252 void InvokeInterfaceFunction(FuncTy Func, ArgsTy Args,
253                              std::index_sequence<IndexSeq...>) {
254   std::invoke(Func, RegionInterface, std::get<IndexSeq>(Args)...);
255 }
256 
/// Scope guard for OMPT begin/end callback pairs. The constructor stores the
/// given arguments and invokes the "begin" callback on the thread-local
/// RegionInterface. The destructor invokes the matching "end" callback with
/// the same stored arguments. Both invocations are wrapped in
/// performIfOmptInitialized (see Callback.h).
///
/// \p CallbackPairTy must be std::pair-like: it has to provide first_type and
/// second_type, e.g. the result of Interface::getCallbacks<Op>().
template <typename CallbackPairTy, typename... ArgsTy> class InterfaceRAII {
public:
  InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
      : Arguments(Args...), BeginFunction(std::get<0>(Callbacks)),
        EndFunction(std::get<1>(Callbacks)) {
    performIfOmptInitialized(begin());
  }
  ~InterfaceRAII() { performIfOmptInitialized(end()); }

  // The guard is bound to a single scope. A copy (or a moved-to object) has
  // no "disengaged" state, so it would run the "end" callback a second time
  // for one "begin" callback.
  InterfaceRAII(const InterfaceRAII &) = delete;
  InterfaceRAII &operator=(const InterfaceRAII &) = delete;
  InterfaceRAII(InterfaceRAII &&) = delete;
  InterfaceRAII &operator=(InterfaceRAII &&) = delete;

private:
  /// Index sequence covering every stored argument, in order.
  using ArgIndicesTy = std::index_sequence_for<ArgsTy...>;

  /// Dispatches the "begin" callback with the stored arguments.
  void begin() {
    InvokeInterfaceFunction(BeginFunction, Arguments, ArgIndicesTy{});
  }

  /// Dispatches the "end" callback with the stored arguments.
  void end() {
    InvokeInterfaceFunction(EndFunction, Arguments, ArgIndicesTy{});
  }

  std::tuple<ArgsTy...> Arguments;
  typename CallbackPairTy::first_type BeginFunction;
  typename CallbackPairTy::second_type EndFunction;
};
283 
284 // InterfaceRAII's class template argument deduction guide
285 template <typename CallbackPairTy, typename... ArgsTy>
286 InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
287     -> InterfaceRAII<CallbackPairTy, ArgsTy...>;
288 
289 /// Used to set and reset the thread-local return address. The RAII is expected
290 /// to be created at a runtime entry point when the return address should be
291 /// null. If so, the return address is set and \p IsSetter is set in the ctor.
292 /// The dtor resets the return address only if the corresponding object set it.
293 /// So if the RAII is called from a nested runtime function, the ctor/dtor will
294 /// do nothing since the thread local return address is already set.
295 class ReturnAddressSetterRAII {
296 public:
297   ReturnAddressSetterRAII(void *RA) : IsSetter(false) {
298     // Handle nested calls. If already set, do not set again since it
299     // must be in a nested call.
300     if (ReturnAddress == nullptr) {
301       // Store the return address to a thread local variable.
302       ReturnAddress = RA;
303       IsSetter = true;
304     }
305   }
306   ~ReturnAddressSetterRAII() {
307     // Reset the return address if this object set it.
308     if (IsSetter)
309       ReturnAddress = nullptr;
310   }
311 
312 private:
313   // Did this object set the thread-local return address?
314   bool IsSetter;
315 };
316 
317 } // namespace ompt
318 } // namespace target
319 } // namespace omp
320 } // namespace llvm
321 
// OMPT_GET_RETURN_ADDRESS expands to a direct reference to the thread-local
// variable llvm::omp::target::ompt::ReturnAddress (declared as `void *`),
// i.e. the variable that ReturnAddressSetterRAII sets and resets. Despite the
// "GET" in its name it is a plain variable reference, not a function call.
#define OMPT_GET_RETURN_ADDRESS llvm::omp::target::ompt::ReturnAddress
324 
325 #else
// OMPT support is compiled out: OMPT_IF_BUILT(Stmt) expands to nothing, so the
// wrapped statement is dropped entirely.
#define OMPT_IF_BUILT(Stmt)
327 #endif
328 
329 #endif // OFFLOAD_INCLUDE_OPENMP_OMPT_INTERFACE_H
330