1330d8983SJohannes Doerfert //===- RPC.h - Interface for remote procedure calls from the GPU ----------===// 2330d8983SJohannes Doerfert // 3330d8983SJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4330d8983SJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information. 5330d8983SJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6330d8983SJohannes Doerfert // 7330d8983SJohannes Doerfert //===----------------------------------------------------------------------===// 8330d8983SJohannes Doerfert 9330d8983SJohannes Doerfert #include "RPC.h" 10330d8983SJohannes Doerfert 11330d8983SJohannes Doerfert #include "Shared/Debug.h" 1291f5f974SJoseph Huber #include "Shared/RPCOpcodes.h" 13330d8983SJohannes Doerfert 14330d8983SJohannes Doerfert #include "PluginInterface.h" 15330d8983SJohannes Doerfert 16b4d49fb5SJoseph Huber #include "shared/rpc.h" 17d7c20a6fSJoseph Huber #include "shared/rpc_opcodes.h" 18330d8983SJohannes Doerfert 19330d8983SJohannes Doerfert using namespace llvm; 20330d8983SJohannes Doerfert using namespace omp; 21330d8983SJohannes Doerfert using namespace target; 22330d8983SJohannes Doerfert 2391f5f974SJoseph Huber template <uint32_t NumLanes> 24134401deSJoseph Huber rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, 2591f5f974SJoseph Huber rpc::Server::Port &Port) { 2691f5f974SJoseph Huber 2791f5f974SJoseph Huber switch (Port.get_opcode()) { 28a6ef0debSJoseph Huber case LIBC_MALLOC: { 2991f5f974SJoseph Huber Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) { 3091f5f974SJoseph Huber Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate( 3191f5f974SJoseph Huber Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING)); 3291f5f974SJoseph Huber }); 3391f5f974SJoseph Huber break; 3491f5f974SJoseph Huber } 35a6ef0debSJoseph Huber case LIBC_FREE: { 3691f5f974SJoseph Huber Port.recv([&](rpc::Buffer *Buffer, uint32_t) { 3791f5f974SJoseph Huber Device.free(reinterpret_cast<void *>(Buffer->data[0]), 3891f5f974SJoseph Huber TARGET_ALLOC_DEVICE_NON_BLOCKING); 3991f5f974SJoseph Huber }); 4091f5f974SJoseph Huber break; 4191f5f974SJoseph Huber } 4291f5f974SJoseph Huber case OFFLOAD_HOST_CALL: { 4391f5f974SJoseph Huber uint64_t Sizes[NumLanes] = {0}; 4491f5f974SJoseph Huber unsigned long long Results[NumLanes] = {0}; 4591f5f974SJoseph Huber void *Args[NumLanes] = {nullptr}; 4691f5f974SJoseph Huber Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; }); 4791f5f974SJoseph Huber Port.recv([&](rpc::Buffer *buffer, uint32_t ID) { 4891f5f974SJoseph Huber using FuncPtrTy = unsigned long long (*)(void *); 4991f5f974SJoseph Huber auto Func = reinterpret_cast<FuncPtrTy>(buffer->data[0]); 5091f5f974SJoseph Huber Results[ID] = Func(Args[ID]); 5191f5f974SJoseph Huber }); 5291f5f974SJoseph Huber Port.send([&](rpc::Buffer *Buffer, uint32_t ID) { 5391f5f974SJoseph Huber Buffer->data[0] = static_cast<uint64_t>(Results[ID]); 5491f5f974SJoseph Huber delete[] reinterpret_cast<char *>(Args[ID]); 5591f5f974SJoseph Huber }); 5691f5f974SJoseph Huber break; 5791f5f974SJoseph Huber } 5891f5f974SJoseph Huber default: 59e85a9f55SJinsong Ji return rpc::RPC_UNHANDLED_OPCODE; 6091f5f974SJoseph Huber break; 6191f5f974SJoseph Huber } 62e85a9f55SJinsong Ji return rpc::RPC_SUCCESS; 6391f5f974SJoseph Huber } 6491f5f974SJoseph Huber 65134401deSJoseph Huber static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, 6691f5f974SJoseph Huber rpc::Server::Port &Port, 6791f5f974SJoseph Huber uint32_t NumLanes) { 6891f5f974SJoseph Huber if (NumLanes == 1) 69134401deSJoseph Huber return handleOffloadOpcodes<1>(Device, Port); 7091f5f974SJoseph Huber else if (NumLanes == 32) 71134401deSJoseph Huber return handleOffloadOpcodes<32>(Device, Port); 7291f5f974SJoseph Huber else if (NumLanes == 64) 73134401deSJoseph Huber return handleOffloadOpcodes<64>(Device, Port); 7491f5f974SJoseph Huber else 75e85a9f55SJinsong Ji return rpc::RPC_ERROR; 7691f5f974SJoseph Huber } 7791f5f974SJoseph Huber 78134401deSJoseph Huber static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) { 79134401deSJoseph Huber uint64_t NumPorts = 80134401deSJoseph Huber std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); 81134401deSJoseph Huber rpc::Server Server(NumPorts, Buffer); 82134401deSJoseph Huber 83134401deSJoseph Huber auto Port = Server.try_open(Device.getWarpSize()); 84134401deSJoseph Huber if (!Port) 85134401deSJoseph Huber return rpc::RPC_SUCCESS; 86134401deSJoseph Huber 87134401deSJoseph Huber rpc::Status Status = 88134401deSJoseph Huber handleOffloadOpcodes(Device, *Port, Device.getWarpSize()); 89134401deSJoseph Huber 90134401deSJoseph Huber // Let the `libc` library handle any other unhandled opcodes. 91134401deSJoseph Huber #ifdef LIBOMPTARGET_RPC_SUPPORT 92134401deSJoseph Huber if (Status == rpc::RPC_UNHANDLED_OPCODE) 93134401deSJoseph Huber Status = handle_libc_opcodes(*Port, Device.getWarpSize()); 94134401deSJoseph Huber #endif 95134401deSJoseph Huber 96134401deSJoseph Huber Port->close(); 97134401deSJoseph Huber 98134401deSJoseph Huber return Status; 99134401deSJoseph Huber } 100134401deSJoseph Huber 101134401deSJoseph Huber void RPCServerTy::ServerThread::startThread() { 102*e7592d83SJoseph Huber assert(!Running.load(std::memory_order_relaxed) && 103*e7592d83SJoseph Huber "Attempting to start thread that is already running"); 104*e7592d83SJoseph Huber Running.store(true, std::memory_order_release); 105134401deSJoseph Huber Worker = std::thread([this]() { run(); }); 106134401deSJoseph Huber } 107134401deSJoseph Huber 108134401deSJoseph Huber void RPCServerTy::ServerThread::shutDown() { 109*e7592d83SJoseph Huber assert(Running.load(std::memory_order_relaxed) && 110*e7592d83SJoseph Huber "Attempting to shut down a thread that is not running"); 111134401deSJoseph Huber { 112134401deSJoseph Huber std::lock_guard<decltype(Mutex)> Lock(Mutex); 113134401deSJoseph Huber Running.store(false, std::memory_order_release); 114134401deSJoseph Huber CV.notify_all(); 115134401deSJoseph Huber } 116134401deSJoseph Huber if (Worker.joinable()) 117134401deSJoseph Huber Worker.join(); 118134401deSJoseph Huber } 119134401deSJoseph Huber 120134401deSJoseph Huber void RPCServerTy::ServerThread::run() { 121134401deSJoseph Huber std::unique_lock<decltype(Mutex)> Lock(Mutex); 122134401deSJoseph Huber for (;;) { 123134401deSJoseph Huber CV.wait(Lock, [&]() { 124134401deSJoseph Huber return NumUsers.load(std::memory_order_acquire) > 0 || 125134401deSJoseph Huber !Running.load(std::memory_order_acquire); 126134401deSJoseph Huber }); 127134401deSJoseph Huber 128134401deSJoseph Huber if (!Running.load(std::memory_order_acquire)) 129134401deSJoseph Huber return; 130134401deSJoseph Huber 131134401deSJoseph Huber Lock.unlock(); 132134401deSJoseph Huber while (NumUsers.load(std::memory_order_relaxed) > 0 && 133134401deSJoseph Huber Running.load(std::memory_order_relaxed)) { 134134401deSJoseph Huber for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) { 135134401deSJoseph Huber if (!Buffer || !Device) 136134401deSJoseph Huber continue; 137134401deSJoseph Huber 138134401deSJoseph Huber // If running the server failed, print a message but keep running. 139134401deSJoseph Huber if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS) 140134401deSJoseph Huber FAILURE_MESSAGE("Unhandled or invalid RPC opcode!"); 141134401deSJoseph Huber } 142134401deSJoseph Huber } 143134401deSJoseph Huber Lock.lock(); 144134401deSJoseph Huber } 145134401deSJoseph Huber } 146134401deSJoseph Huber 147330d8983SJohannes Doerfert RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin) 148134401deSJoseph Huber : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())), 149134401deSJoseph Huber Devices(std::make_unique<plugin::GenericDeviceTy *[]>( 150134401deSJoseph Huber Plugin.getNumDevices())), 151134401deSJoseph Huber Thread(new ServerThread(Buffers.get(), Devices.get(), 152134401deSJoseph Huber Plugin.getNumDevices())) {} 153134401deSJoseph Huber 154134401deSJoseph Huber llvm::Error RPCServerTy::startThread() { 155134401deSJoseph Huber Thread->startThread(); 156134401deSJoseph Huber return Error::success(); 157134401deSJoseph Huber } 158134401deSJoseph Huber 159134401deSJoseph Huber llvm::Error RPCServerTy::shutDown() { 160134401deSJoseph Huber Thread->shutDown(); 161134401deSJoseph Huber return Error::success(); 162134401deSJoseph Huber } 163330d8983SJohannes Doerfert 164330d8983SJohannes Doerfert llvm::Expected<bool> 165330d8983SJohannes Doerfert RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device, 166330d8983SJohannes Doerfert plugin::GenericGlobalHandlerTy &Handler, 167330d8983SJohannes Doerfert plugin::DeviceImageTy &Image) { 16889d8e700SJoseph Huber return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client"); 169330d8983SJohannes Doerfert } 170330d8983SJohannes Doerfert 171330d8983SJohannes Doerfert Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, 172330d8983SJohannes Doerfert plugin::GenericGlobalHandlerTy &Handler, 173330d8983SJohannes Doerfert plugin::DeviceImageTy &Image) { 174330d8983SJohannes Doerfert uint64_t NumPorts = 175b4d49fb5SJoseph Huber std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); 176b4d49fb5SJoseph Huber void *RPCBuffer = Device.allocate( 177b4d49fb5SJoseph Huber rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr, 178b4d49fb5SJoseph Huber TARGET_ALLOC_HOST); 179b4d49fb5SJoseph Huber if (!RPCBuffer) 180330d8983SJohannes Doerfert return plugin::Plugin::error( 181b4d49fb5SJoseph Huber "Failed to initialize RPC server for device %d", Device.getDeviceId()); 182330d8983SJohannes Doerfert 183330d8983SJohannes Doerfert // Get the address of the RPC client from the device. 18489d8e700SJoseph Huber plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client)); 185330d8983SJohannes Doerfert if (auto Err = 186330d8983SJohannes Doerfert Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal)) 187330d8983SJohannes Doerfert return Err; 188330d8983SJohannes Doerfert 189b4d49fb5SJoseph Huber rpc::Client client(NumPorts, RPCBuffer); 19089d8e700SJoseph Huber if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client, 19189d8e700SJoseph Huber sizeof(rpc::Client), nullptr)) 192330d8983SJohannes Doerfert return Err; 193b4d49fb5SJoseph Huber Buffers[Device.getDeviceId()] = RPCBuffer; 194134401deSJoseph Huber Devices[Device.getDeviceId()] = &Device; 195b4d49fb5SJoseph Huber 196b4d49fb5SJoseph Huber return Error::success(); 197330d8983SJohannes Doerfert } 198330d8983SJohannes Doerfert 199330d8983SJohannes Doerfert Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { 200b4d49fb5SJoseph Huber Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST); 201134401deSJoseph Huber Buffers[Device.getDeviceId()] = nullptr; 202134401deSJoseph Huber Devices[Device.getDeviceId()] = nullptr; 203b4d49fb5SJoseph Huber return Error::success(); 204330d8983SJohannes Doerfert } 205