xref: /llvm-project/offload/plugins-nextgen/common/src/RPC.cpp (revision e7592d83e0ac58f61cfe8dcf61bcc8e7a8bd67b3)
//===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "RPC.h"
10 
11 #include "Shared/Debug.h"
12 #include "Shared/RPCOpcodes.h"
13 
14 #include "PluginInterface.h"
15 
16 #include "shared/rpc.h"
17 #include "shared/rpc_opcodes.h"
18 
19 using namespace llvm;
20 using namespace omp;
21 using namespace target;
22 
23 template <uint32_t NumLanes>
24 rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
25                                  rpc::Server::Port &Port) {
26 
27   switch (Port.get_opcode()) {
28   case LIBC_MALLOC: {
29     Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
30       Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
31           Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
32     });
33     break;
34   }
35   case LIBC_FREE: {
36     Port.recv([&](rpc::Buffer *Buffer, uint32_t) {
37       Device.free(reinterpret_cast<void *>(Buffer->data[0]),
38                   TARGET_ALLOC_DEVICE_NON_BLOCKING);
39     });
40     break;
41   }
42   case OFFLOAD_HOST_CALL: {
43     uint64_t Sizes[NumLanes] = {0};
44     unsigned long long Results[NumLanes] = {0};
45     void *Args[NumLanes] = {nullptr};
46     Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; });
47     Port.recv([&](rpc::Buffer *buffer, uint32_t ID) {
48       using FuncPtrTy = unsigned long long (*)(void *);
49       auto Func = reinterpret_cast<FuncPtrTy>(buffer->data[0]);
50       Results[ID] = Func(Args[ID]);
51     });
52     Port.send([&](rpc::Buffer *Buffer, uint32_t ID) {
53       Buffer->data[0] = static_cast<uint64_t>(Results[ID]);
54       delete[] reinterpret_cast<char *>(Args[ID]);
55     });
56     break;
57   }
58   default:
59     return rpc::RPC_UNHANDLED_OPCODE;
60     break;
61   }
62   return rpc::RPC_SUCCESS;
63 }
64 
65 static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
66                                         rpc::Server::Port &Port,
67                                         uint32_t NumLanes) {
68   if (NumLanes == 1)
69     return handleOffloadOpcodes<1>(Device, Port);
70   else if (NumLanes == 32)
71     return handleOffloadOpcodes<32>(Device, Port);
72   else if (NumLanes == 64)
73     return handleOffloadOpcodes<64>(Device, Port);
74   else
75     return rpc::RPC_ERROR;
76 }
77 
78 static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
79   uint64_t NumPorts =
80       std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
81   rpc::Server Server(NumPorts, Buffer);
82 
83   auto Port = Server.try_open(Device.getWarpSize());
84   if (!Port)
85     return rpc::RPC_SUCCESS;
86 
87   rpc::Status Status =
88       handleOffloadOpcodes(Device, *Port, Device.getWarpSize());
89 
90   // Let the `libc` library handle any other unhandled opcodes.
91 #ifdef LIBOMPTARGET_RPC_SUPPORT
92   if (Status == rpc::RPC_UNHANDLED_OPCODE)
93     Status = handle_libc_opcodes(*Port, Device.getWarpSize());
94 #endif
95 
96   Port->close();
97 
98   return Status;
99 }
100 
101 void RPCServerTy::ServerThread::startThread() {
102   assert(!Running.load(std::memory_order_relaxed) &&
103          "Attempting to start thread that is already running");
104   Running.store(true, std::memory_order_release);
105   Worker = std::thread([this]() { run(); });
106 }
107 
108 void RPCServerTy::ServerThread::shutDown() {
109   assert(Running.load(std::memory_order_relaxed) &&
110          "Attempting to shut down a thread that is not running");
111   {
112     std::lock_guard<decltype(Mutex)> Lock(Mutex);
113     Running.store(false, std::memory_order_release);
114     CV.notify_all();
115   }
116   if (Worker.joinable())
117     Worker.join();
118 }
119 
120 void RPCServerTy::ServerThread::run() {
121   std::unique_lock<decltype(Mutex)> Lock(Mutex);
122   for (;;) {
123     CV.wait(Lock, [&]() {
124       return NumUsers.load(std::memory_order_acquire) > 0 ||
125              !Running.load(std::memory_order_acquire);
126     });
127 
128     if (!Running.load(std::memory_order_acquire))
129       return;
130 
131     Lock.unlock();
132     while (NumUsers.load(std::memory_order_relaxed) > 0 &&
133            Running.load(std::memory_order_relaxed)) {
134       for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
135         if (!Buffer || !Device)
136           continue;
137 
138         // If running the server failed, print a message but keep running.
139         if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS)
140           FAILURE_MESSAGE("Unhandled or invalid RPC opcode!");
141       }
142     }
143     Lock.lock();
144   }
145 }
146 
147 RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
148     : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
149       Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
150           Plugin.getNumDevices())),
151       Thread(new ServerThread(Buffers.get(), Devices.get(),
152                               Plugin.getNumDevices())) {}
153 
154 llvm::Error RPCServerTy::startThread() {
155   Thread->startThread();
156   return Error::success();
157 }
158 
159 llvm::Error RPCServerTy::shutDown() {
160   Thread->shutDown();
161   return Error::success();
162 }
163 
164 llvm::Expected<bool>
165 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
166                               plugin::GenericGlobalHandlerTy &Handler,
167                               plugin::DeviceImageTy &Image) {
168   return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
169 }
170 
171 Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
172                               plugin::GenericGlobalHandlerTy &Handler,
173                               plugin::DeviceImageTy &Image) {
174   uint64_t NumPorts =
175       std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
176   void *RPCBuffer = Device.allocate(
177       rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
178       TARGET_ALLOC_HOST);
179   if (!RPCBuffer)
180     return plugin::Plugin::error(
181         "Failed to initialize RPC server for device %d", Device.getDeviceId());
182 
183   // Get the address of the RPC client from the device.
184   plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client));
185   if (auto Err =
186           Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
187     return Err;
188 
189   rpc::Client client(NumPorts, RPCBuffer);
190   if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
191                                    sizeof(rpc::Client), nullptr))
192     return Err;
193   Buffers[Device.getDeviceId()] = RPCBuffer;
194   Devices[Device.getDeviceId()] = &Device;
195 
196   return Error::success();
197 }
198 
199 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
200   Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
201   Buffers[Device.getDeviceId()] = nullptr;
202   Devices[Device.getDeviceId()] = nullptr;
203   return Error::success();
204 }
205