xref: /llvm-project/offload/plugins-nextgen/common/src/RPC.cpp (revision e7592d83e0ac58f61cfe8dcf61bcc8e7a8bd67b3)
//===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
2330d8983SJohannes Doerfert //
3330d8983SJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4330d8983SJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information.
5330d8983SJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6330d8983SJohannes Doerfert //
7330d8983SJohannes Doerfert //===----------------------------------------------------------------------===//
8330d8983SJohannes Doerfert 
9330d8983SJohannes Doerfert #include "RPC.h"
10330d8983SJohannes Doerfert 
11330d8983SJohannes Doerfert #include "Shared/Debug.h"
1291f5f974SJoseph Huber #include "Shared/RPCOpcodes.h"
13330d8983SJohannes Doerfert 
14330d8983SJohannes Doerfert #include "PluginInterface.h"
15330d8983SJohannes Doerfert 
16b4d49fb5SJoseph Huber #include "shared/rpc.h"
17d7c20a6fSJoseph Huber #include "shared/rpc_opcodes.h"
18330d8983SJohannes Doerfert 
19330d8983SJohannes Doerfert using namespace llvm;
20330d8983SJohannes Doerfert using namespace omp;
21330d8983SJohannes Doerfert using namespace target;
22330d8983SJohannes Doerfert 
2391f5f974SJoseph Huber template <uint32_t NumLanes>
24134401deSJoseph Huber rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
2591f5f974SJoseph Huber                                  rpc::Server::Port &Port) {
2691f5f974SJoseph Huber 
2791f5f974SJoseph Huber   switch (Port.get_opcode()) {
28a6ef0debSJoseph Huber   case LIBC_MALLOC: {
2991f5f974SJoseph Huber     Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
3091f5f974SJoseph Huber       Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
3191f5f974SJoseph Huber           Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
3291f5f974SJoseph Huber     });
3391f5f974SJoseph Huber     break;
3491f5f974SJoseph Huber   }
35a6ef0debSJoseph Huber   case LIBC_FREE: {
3691f5f974SJoseph Huber     Port.recv([&](rpc::Buffer *Buffer, uint32_t) {
3791f5f974SJoseph Huber       Device.free(reinterpret_cast<void *>(Buffer->data[0]),
3891f5f974SJoseph Huber                   TARGET_ALLOC_DEVICE_NON_BLOCKING);
3991f5f974SJoseph Huber     });
4091f5f974SJoseph Huber     break;
4191f5f974SJoseph Huber   }
4291f5f974SJoseph Huber   case OFFLOAD_HOST_CALL: {
4391f5f974SJoseph Huber     uint64_t Sizes[NumLanes] = {0};
4491f5f974SJoseph Huber     unsigned long long Results[NumLanes] = {0};
4591f5f974SJoseph Huber     void *Args[NumLanes] = {nullptr};
4691f5f974SJoseph Huber     Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; });
4791f5f974SJoseph Huber     Port.recv([&](rpc::Buffer *buffer, uint32_t ID) {
4891f5f974SJoseph Huber       using FuncPtrTy = unsigned long long (*)(void *);
4991f5f974SJoseph Huber       auto Func = reinterpret_cast<FuncPtrTy>(buffer->data[0]);
5091f5f974SJoseph Huber       Results[ID] = Func(Args[ID]);
5191f5f974SJoseph Huber     });
5291f5f974SJoseph Huber     Port.send([&](rpc::Buffer *Buffer, uint32_t ID) {
5391f5f974SJoseph Huber       Buffer->data[0] = static_cast<uint64_t>(Results[ID]);
5491f5f974SJoseph Huber       delete[] reinterpret_cast<char *>(Args[ID]);
5591f5f974SJoseph Huber     });
5691f5f974SJoseph Huber     break;
5791f5f974SJoseph Huber   }
5891f5f974SJoseph Huber   default:
59e85a9f55SJinsong Ji     return rpc::RPC_UNHANDLED_OPCODE;
6091f5f974SJoseph Huber     break;
6191f5f974SJoseph Huber   }
62e85a9f55SJinsong Ji   return rpc::RPC_SUCCESS;
6391f5f974SJoseph Huber }
6491f5f974SJoseph Huber 
65134401deSJoseph Huber static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
6691f5f974SJoseph Huber                                         rpc::Server::Port &Port,
6791f5f974SJoseph Huber                                         uint32_t NumLanes) {
6891f5f974SJoseph Huber   if (NumLanes == 1)
69134401deSJoseph Huber     return handleOffloadOpcodes<1>(Device, Port);
7091f5f974SJoseph Huber   else if (NumLanes == 32)
71134401deSJoseph Huber     return handleOffloadOpcodes<32>(Device, Port);
7291f5f974SJoseph Huber   else if (NumLanes == 64)
73134401deSJoseph Huber     return handleOffloadOpcodes<64>(Device, Port);
7491f5f974SJoseph Huber   else
75e85a9f55SJinsong Ji     return rpc::RPC_ERROR;
7691f5f974SJoseph Huber }
7791f5f974SJoseph Huber 
78134401deSJoseph Huber static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
79134401deSJoseph Huber   uint64_t NumPorts =
80134401deSJoseph Huber       std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
81134401deSJoseph Huber   rpc::Server Server(NumPorts, Buffer);
82134401deSJoseph Huber 
83134401deSJoseph Huber   auto Port = Server.try_open(Device.getWarpSize());
84134401deSJoseph Huber   if (!Port)
85134401deSJoseph Huber     return rpc::RPC_SUCCESS;
86134401deSJoseph Huber 
87134401deSJoseph Huber   rpc::Status Status =
88134401deSJoseph Huber       handleOffloadOpcodes(Device, *Port, Device.getWarpSize());
89134401deSJoseph Huber 
90134401deSJoseph Huber   // Let the `libc` library handle any other unhandled opcodes.
91134401deSJoseph Huber #ifdef LIBOMPTARGET_RPC_SUPPORT
92134401deSJoseph Huber   if (Status == rpc::RPC_UNHANDLED_OPCODE)
93134401deSJoseph Huber     Status = handle_libc_opcodes(*Port, Device.getWarpSize());
94134401deSJoseph Huber #endif
95134401deSJoseph Huber 
96134401deSJoseph Huber   Port->close();
97134401deSJoseph Huber 
98134401deSJoseph Huber   return Status;
99134401deSJoseph Huber }
100134401deSJoseph Huber 
101134401deSJoseph Huber void RPCServerTy::ServerThread::startThread() {
102*e7592d83SJoseph Huber   assert(!Running.load(std::memory_order_relaxed) &&
103*e7592d83SJoseph Huber          "Attempting to start thread that is already running");
104*e7592d83SJoseph Huber   Running.store(true, std::memory_order_release);
105134401deSJoseph Huber   Worker = std::thread([this]() { run(); });
106134401deSJoseph Huber }
107134401deSJoseph Huber 
108134401deSJoseph Huber void RPCServerTy::ServerThread::shutDown() {
109*e7592d83SJoseph Huber   assert(Running.load(std::memory_order_relaxed) &&
110*e7592d83SJoseph Huber          "Attempting to shut down a thread that is not running");
111134401deSJoseph Huber   {
112134401deSJoseph Huber     std::lock_guard<decltype(Mutex)> Lock(Mutex);
113134401deSJoseph Huber     Running.store(false, std::memory_order_release);
114134401deSJoseph Huber     CV.notify_all();
115134401deSJoseph Huber   }
116134401deSJoseph Huber   if (Worker.joinable())
117134401deSJoseph Huber     Worker.join();
118134401deSJoseph Huber }
119134401deSJoseph Huber 
120134401deSJoseph Huber void RPCServerTy::ServerThread::run() {
121134401deSJoseph Huber   std::unique_lock<decltype(Mutex)> Lock(Mutex);
122134401deSJoseph Huber   for (;;) {
123134401deSJoseph Huber     CV.wait(Lock, [&]() {
124134401deSJoseph Huber       return NumUsers.load(std::memory_order_acquire) > 0 ||
125134401deSJoseph Huber              !Running.load(std::memory_order_acquire);
126134401deSJoseph Huber     });
127134401deSJoseph Huber 
128134401deSJoseph Huber     if (!Running.load(std::memory_order_acquire))
129134401deSJoseph Huber       return;
130134401deSJoseph Huber 
131134401deSJoseph Huber     Lock.unlock();
132134401deSJoseph Huber     while (NumUsers.load(std::memory_order_relaxed) > 0 &&
133134401deSJoseph Huber            Running.load(std::memory_order_relaxed)) {
134134401deSJoseph Huber       for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
135134401deSJoseph Huber         if (!Buffer || !Device)
136134401deSJoseph Huber           continue;
137134401deSJoseph Huber 
138134401deSJoseph Huber         // If running the server failed, print a message but keep running.
139134401deSJoseph Huber         if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS)
140134401deSJoseph Huber           FAILURE_MESSAGE("Unhandled or invalid RPC opcode!");
141134401deSJoseph Huber       }
142134401deSJoseph Huber     }
143134401deSJoseph Huber     Lock.lock();
144134401deSJoseph Huber   }
145134401deSJoseph Huber }
146134401deSJoseph Huber 
147330d8983SJohannes Doerfert RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
148134401deSJoseph Huber     : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
149134401deSJoseph Huber       Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
150134401deSJoseph Huber           Plugin.getNumDevices())),
151134401deSJoseph Huber       Thread(new ServerThread(Buffers.get(), Devices.get(),
152134401deSJoseph Huber                               Plugin.getNumDevices())) {}
153134401deSJoseph Huber 
154134401deSJoseph Huber llvm::Error RPCServerTy::startThread() {
155134401deSJoseph Huber   Thread->startThread();
156134401deSJoseph Huber   return Error::success();
157134401deSJoseph Huber }
158134401deSJoseph Huber 
159134401deSJoseph Huber llvm::Error RPCServerTy::shutDown() {
160134401deSJoseph Huber   Thread->shutDown();
161134401deSJoseph Huber   return Error::success();
162134401deSJoseph Huber }
163330d8983SJohannes Doerfert 
164330d8983SJohannes Doerfert llvm::Expected<bool>
165330d8983SJohannes Doerfert RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
166330d8983SJohannes Doerfert                               plugin::GenericGlobalHandlerTy &Handler,
167330d8983SJohannes Doerfert                               plugin::DeviceImageTy &Image) {
16889d8e700SJoseph Huber   return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
169330d8983SJohannes Doerfert }
170330d8983SJohannes Doerfert 
171330d8983SJohannes Doerfert Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
172330d8983SJohannes Doerfert                               plugin::GenericGlobalHandlerTy &Handler,
173330d8983SJohannes Doerfert                               plugin::DeviceImageTy &Image) {
174330d8983SJohannes Doerfert   uint64_t NumPorts =
175b4d49fb5SJoseph Huber       std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
176b4d49fb5SJoseph Huber   void *RPCBuffer = Device.allocate(
177b4d49fb5SJoseph Huber       rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
178b4d49fb5SJoseph Huber       TARGET_ALLOC_HOST);
179b4d49fb5SJoseph Huber   if (!RPCBuffer)
180330d8983SJohannes Doerfert     return plugin::Plugin::error(
181b4d49fb5SJoseph Huber         "Failed to initialize RPC server for device %d", Device.getDeviceId());
182330d8983SJohannes Doerfert 
183330d8983SJohannes Doerfert   // Get the address of the RPC client from the device.
18489d8e700SJoseph Huber   plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client));
185330d8983SJohannes Doerfert   if (auto Err =
186330d8983SJohannes Doerfert           Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
187330d8983SJohannes Doerfert     return Err;
188330d8983SJohannes Doerfert 
189b4d49fb5SJoseph Huber   rpc::Client client(NumPorts, RPCBuffer);
19089d8e700SJoseph Huber   if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
19189d8e700SJoseph Huber                                    sizeof(rpc::Client), nullptr))
192330d8983SJohannes Doerfert     return Err;
193b4d49fb5SJoseph Huber   Buffers[Device.getDeviceId()] = RPCBuffer;
194134401deSJoseph Huber   Devices[Device.getDeviceId()] = &Device;
195b4d49fb5SJoseph Huber 
196b4d49fb5SJoseph Huber   return Error::success();
197330d8983SJohannes Doerfert }
198330d8983SJohannes Doerfert 
199330d8983SJohannes Doerfert Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
200b4d49fb5SJoseph Huber   Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
201134401deSJoseph Huber   Buffers[Device.getDeviceId()] = nullptr;
202134401deSJoseph Huber   Devices[Device.getDeviceId()] = nullptr;
203b4d49fb5SJoseph Huber   return Error::success();
204330d8983SJohannes Doerfert }
205