1 //===- RawByteChannel.h -----------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H 10 #define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H 11 12 #include "llvm/ADT/StringRef.h" 13 #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" 14 #include "llvm/Support/Endian.h" 15 #include "llvm/Support/Error.h" 16 #include <cstdint> 17 #include <mutex> 18 #include <string> 19 #include <type_traits> 20 21 namespace llvm { 22 namespace orc { 23 namespace shared { 24 25 /// Interface for byte-streams to be used with ORC Serialization. 26 class RawByteChannel { 27 public: 28 virtual ~RawByteChannel() = default; 29 30 /// Read Size bytes from the stream into *Dst. 31 virtual Error readBytes(char *Dst, unsigned Size) = 0; 32 33 /// Read size bytes from *Src and append them to the stream. 34 virtual Error appendBytes(const char *Src, unsigned Size) = 0; 35 36 /// Flush the stream if possible. 37 virtual Error send() = 0; 38 39 /// Notify the channel that we're starting a message send. 40 /// Locks the channel for writing. 41 template <typename FunctionIdT, typename SequenceIdT> startSendMessage(const FunctionIdT & FnId,const SequenceIdT & SeqNo)42 Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { 43 writeLock.lock(); 44 if (auto Err = serializeSeq(*this, FnId, SeqNo)) { 45 writeLock.unlock(); 46 return Err; 47 } 48 return Error::success(); 49 } 50 51 /// Notify the channel that we're ending a message send. 52 /// Unlocks the channel for writing. endSendMessage()53 Error endSendMessage() { 54 writeLock.unlock(); 55 return Error::success(); 56 } 57 58 /// Notify the channel that we're starting a message receive. 59 /// Locks the channel for reading. 60 template <typename FunctionIdT, typename SequenceNumberT> startReceiveMessage(FunctionIdT & FnId,SequenceNumberT & SeqNo)61 Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { 62 readLock.lock(); 63 if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { 64 readLock.unlock(); 65 return Err; 66 } 67 return Error::success(); 68 } 69 70 /// Notify the channel that we're ending a message receive. 71 /// Unlocks the channel for reading. endReceiveMessage()72 Error endReceiveMessage() { 73 readLock.unlock(); 74 return Error::success(); 75 } 76 77 /// Get the lock for stream reading. getReadLock()78 std::mutex &getReadLock() { return readLock; } 79 80 /// Get the lock for stream writing. getWriteLock()81 std::mutex &getWriteLock() { return writeLock; } 82 83 private: 84 std::mutex readLock, writeLock; 85 }; 86 87 template <typename ChannelT, typename T> 88 class SerializationTraits< 89 ChannelT, T, T, 90 std::enable_if_t< 91 std::is_base_of<RawByteChannel, ChannelT>::value && 92 (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || 93 std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || 94 std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || 95 std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || 96 std::is_same<T, char>::value)>> { 97 public: serialize(ChannelT & C,T V)98 static Error serialize(ChannelT &C, T V) { 99 support::endian::byte_swap<T, support::big>(V); 100 return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); 101 }; 102 deserialize(ChannelT & C,T & V)103 static Error deserialize(ChannelT &C, T &V) { 104 if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) 105 return Err; 106 support::endian::byte_swap<T, support::big>(V); 107 return Error::success(); 108 }; 109 }; 110 111 template <typename ChannelT> 112 class SerializationTraits< 113 ChannelT, bool, bool, 114 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { 115 public: serialize(ChannelT & C,bool V)116 static Error serialize(ChannelT &C, bool V) { 117 uint8_t Tmp = V ? 1 : 0; 118 if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) 119 return Err; 120 return Error::success(); 121 } 122 deserialize(ChannelT & C,bool & V)123 static Error deserialize(ChannelT &C, bool &V) { 124 uint8_t Tmp = 0; 125 if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) 126 return Err; 127 V = Tmp != 0; 128 return Error::success(); 129 } 130 }; 131 132 template <typename ChannelT> 133 class SerializationTraits< 134 ChannelT, std::string, StringRef, 135 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { 136 public: 137 /// Serialization channel serialization for std::strings. serialize(RawByteChannel & C,StringRef S)138 static Error serialize(RawByteChannel &C, StringRef S) { 139 if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) 140 return Err; 141 return C.appendBytes((const char *)S.data(), S.size()); 142 } 143 }; 144 145 template <typename ChannelT, typename T> 146 class SerializationTraits< 147 ChannelT, std::string, T, 148 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && 149 (std::is_same<T, const char *>::value || 150 std::is_same<T, char *>::value)>> { 151 public: serialize(RawByteChannel & C,const char * S)152 static Error serialize(RawByteChannel &C, const char *S) { 153 return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, 154 S); 155 } 156 }; 157 158 template <typename ChannelT> 159 class SerializationTraits< 160 ChannelT, std::string, std::string, 161 std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { 162 public: 163 /// Serialization channel serialization for std::strings. serialize(RawByteChannel & C,const std::string & S)164 static Error serialize(RawByteChannel &C, const std::string &S) { 165 return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, 166 S); 167 } 168 169 /// Serialization channel deserialization for std::strings. deserialize(RawByteChannel & C,std::string & S)170 static Error deserialize(RawByteChannel &C, std::string &S) { 171 uint64_t Count = 0; 172 if (auto Err = deserializeSeq(C, Count)) 173 return Err; 174 S.resize(Count); 175 return C.readBytes(&S[0], Count); 176 } 177 }; 178 179 } // end namespace shared 180 } // end namespace orc 181 } // end namespace llvm 182 183 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H 184