1 //===- RPC.h - Interface for remote procedure calls from the GPU ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "RPC.h" 10 11 #include "Shared/Debug.h" 12 #include "Shared/RPCOpcodes.h" 13 14 #include "PluginInterface.h" 15 16 #include "shared/rpc.h" 17 #include "shared/rpc_opcodes.h" 18 19 using namespace llvm; 20 using namespace omp; 21 using namespace target; 22 23 template <uint32_t NumLanes> 24 rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, 25 rpc::Server::Port &Port) { 26 27 switch (Port.get_opcode()) { 28 case LIBC_MALLOC: { 29 Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) { 30 Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate( 31 Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING)); 32 }); 33 break; 34 } 35 case LIBC_FREE: { 36 Port.recv([&](rpc::Buffer *Buffer, uint32_t) { 37 Device.free(reinterpret_cast<void *>(Buffer->data[0]), 38 TARGET_ALLOC_DEVICE_NON_BLOCKING); 39 }); 40 break; 41 } 42 case OFFLOAD_HOST_CALL: { 43 uint64_t Sizes[NumLanes] = {0}; 44 unsigned long long Results[NumLanes] = {0}; 45 void *Args[NumLanes] = {nullptr}; 46 Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; }); 47 Port.recv([&](rpc::Buffer *buffer, uint32_t ID) { 48 using FuncPtrTy = unsigned long long (*)(void *); 49 auto Func = reinterpret_cast<FuncPtrTy>(buffer->data[0]); 50 Results[ID] = Func(Args[ID]); 51 }); 52 Port.send([&](rpc::Buffer *Buffer, uint32_t ID) { 53 Buffer->data[0] = static_cast<uint64_t>(Results[ID]); 54 delete[] reinterpret_cast<char *>(Args[ID]); 55 }); 56 break; 57 } 58 default: 59 return rpc::RPC_UNHANDLED_OPCODE; 60 break; 61 } 62 return rpc::RPC_SUCCESS; 63 } 64 65 static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, 66 rpc::Server::Port &Port, 67 uint32_t NumLanes) { 68 if (NumLanes == 1) 69 return handleOffloadOpcodes<1>(Device, Port); 70 else if (NumLanes == 32) 71 return handleOffloadOpcodes<32>(Device, Port); 72 else if (NumLanes == 64) 73 return handleOffloadOpcodes<64>(Device, Port); 74 else 75 return rpc::RPC_ERROR; 76 } 77 78 static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) { 79 uint64_t NumPorts = 80 std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); 81 rpc::Server Server(NumPorts, Buffer); 82 83 auto Port = Server.try_open(Device.getWarpSize()); 84 if (!Port) 85 return rpc::RPC_SUCCESS; 86 87 rpc::Status Status = 88 handleOffloadOpcodes(Device, *Port, Device.getWarpSize()); 89 90 // Let the `libc` library handle any other unhandled opcodes. 91 #ifdef LIBOMPTARGET_RPC_SUPPORT 92 if (Status == rpc::RPC_UNHANDLED_OPCODE) 93 Status = handle_libc_opcodes(*Port, Device.getWarpSize()); 94 #endif 95 96 Port->close(); 97 98 return Status; 99 } 100 101 void RPCServerTy::ServerThread::startThread() { 102 assert(!Running.load(std::memory_order_relaxed) && 103 "Attempting to start thread that is already running"); 104 Running.store(true, std::memory_order_release); 105 Worker = std::thread([this]() { run(); }); 106 } 107 108 void RPCServerTy::ServerThread::shutDown() { 109 assert(Running.load(std::memory_order_relaxed) && 110 "Attempting to shut down a thread that is not running"); 111 { 112 std::lock_guard<decltype(Mutex)> Lock(Mutex); 113 Running.store(false, std::memory_order_release); 114 CV.notify_all(); 115 } 116 if (Worker.joinable()) 117 Worker.join(); 118 } 119 120 void RPCServerTy::ServerThread::run() { 121 std::unique_lock<decltype(Mutex)> Lock(Mutex); 122 for (;;) { 123 CV.wait(Lock, [&]() { 124 return NumUsers.load(std::memory_order_acquire) > 0 || 125 !Running.load(std::memory_order_acquire); 126 }); 127 128 if (!Running.load(std::memory_order_acquire)) 129 return; 130 131 Lock.unlock(); 132 while (NumUsers.load(std::memory_order_relaxed) > 0 && 133 Running.load(std::memory_order_relaxed)) { 134 for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) { 135 if (!Buffer || !Device) 136 continue; 137 138 // If running the server failed, print a message but keep running. 139 if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS) 140 FAILURE_MESSAGE("Unhandled or invalid RPC opcode!"); 141 } 142 } 143 Lock.lock(); 144 } 145 } 146 147 RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin) 148 : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())), 149 Devices(std::make_unique<plugin::GenericDeviceTy *[]>( 150 Plugin.getNumDevices())), 151 Thread(new ServerThread(Buffers.get(), Devices.get(), 152 Plugin.getNumDevices())) {} 153 154 llvm::Error RPCServerTy::startThread() { 155 Thread->startThread(); 156 return Error::success(); 157 } 158 159 llvm::Error RPCServerTy::shutDown() { 160 Thread->shutDown(); 161 return Error::success(); 162 } 163 164 llvm::Expected<bool> 165 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device, 166 plugin::GenericGlobalHandlerTy &Handler, 167 plugin::DeviceImageTy &Image) { 168 return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client"); 169 } 170 171 Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, 172 plugin::GenericGlobalHandlerTy &Handler, 173 plugin::DeviceImageTy &Image) { 174 uint64_t NumPorts = 175 std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); 176 void *RPCBuffer = Device.allocate( 177 rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr, 178 TARGET_ALLOC_HOST); 179 if (!RPCBuffer) 180 return plugin::Plugin::error( 181 "Failed to initialize RPC server for device %d", Device.getDeviceId()); 182 183 // Get the address of the RPC client from the device. 184 plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client)); 185 if (auto Err = 186 Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal)) 187 return Err; 188 189 rpc::Client client(NumPorts, RPCBuffer); 190 if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client, 191 sizeof(rpc::Client), nullptr)) 192 return Err; 193 Buffers[Device.getDeviceId()] = RPCBuffer; 194 Devices[Device.getDeviceId()] = &Device; 195 196 return Error::success(); 197 } 198 199 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { 200 Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST); 201 Buffers[Device.getDeviceId()] = nullptr; 202 Devices[Device.getDeviceId()] = nullptr; 203 return Error::success(); 204 } 205