1 //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logging infrastructure for extracting features and 10 // rewards for mlgo policy training. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/Config/config.h" 14 #if defined(LLVM_HAVE_TF_API) 15 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Analysis/Utils/TrainingLogger.h" 18 #include "llvm/Support/Base64.h" 19 #include "llvm/Support/CommandLine.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/JSON.h" 22 #include "llvm/Support/MemoryBuffer.h" 23 #include "llvm/Support/Path.h" 24 #include "llvm/Support/raw_ostream.h" 25 26 #include "google/protobuf/struct.pb.h" 27 #include "google/protobuf/text_format.h" 28 #include "tensorflow/core/example/example.pb.h" 29 #include <cassert> 30 #include <numeric> 31 32 using namespace llvm; 33 34 using google::protobuf::Message; 35 using google::protobuf::TextFormat; 36 37 static cl::opt<bool> 38 ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden, 39 cl::desc("Output textual (human-readable) protobuf.")); 40 41 namespace { 42 43 void serialize(const Message &SE, std::string *OutStr) { 44 if (ProtobufTextMode) { 45 TextFormat::PrintToString(SE, OutStr); 46 } else { 47 *OutStr = SE.SerializeAsString(); 48 } 49 } 50 } // namespace 51 52 namespace llvm { 53 54 class LoggerDataImpl { 55 const std::vector<TensorSpec> LoggedFeatureSpecs; 56 const TensorSpec RewardSpec; 57 const bool IncludeReward; 58 59 std::vector<tensorflow::FeatureList> FeatureLists; 60 tensorflow::FeatureList Reward; 61 62 bool isSelfConsistent(const tensorflow::SequenceExample &SE, 63 size_t NrRecords) const { 64 bool Ret = true; 65 for (const auto &TSpecs : LoggedFeatureSpecs) { 66 const auto &Name = TSpecs.name(); 67 const auto &FL = SE.feature_lists().feature_list().at(Name).feature(); 68 if (NrRecords != static_cast<size_t>(FL.size())) { 69 dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected " 70 << NrRecords << " got " << FL.size() << "\n"; 71 Ret = false; 72 } 73 } 74 if (IncludeReward && static_cast<size_t>(SE.feature_lists() 75 .feature_list() 76 .at(RewardSpec.name()) 77 .feature() 78 .size()) != NrRecords) { 79 dbgs() << "[TF-UTILS]: reward is missing records.\n"; 80 Ret = false; 81 } 82 return Ret; 83 } 84 85 void transferLog(tensorflow::SequenceExample &SE) { 86 auto *FL = SE.mutable_feature_lists()->mutable_feature_list(); 87 if (IncludeReward) 88 (*FL)[RewardSpec.name()] = std::move(Reward); 89 assert(FeatureLists.size() == LoggedFeatureSpecs.size()); 90 for (size_t I = 0; I < FeatureLists.size(); ++I) { 91 const auto &LFS = LoggedFeatureSpecs[I]; 92 (*FL)[LFS.name()] = std::move(FeatureLists[I]); 93 } 94 } 95 96 public: 97 LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs, 98 const TensorSpec &RewardSpec, bool IncludeReward) 99 : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec), 100 IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {} 101 102 // flush the logged info to a stream and clear the log contents. 103 void flush(std::string *Str) { 104 size_t NrRecords = getNrRecords(); 105 (void)NrRecords; 106 tensorflow::SequenceExample SE; 107 transferLog(SE); 108 assert(isSelfConsistent(SE, NrRecords)); 109 serialize(SE, Str); 110 } 111 112 char *addNewTensor(size_t FeatureID) { 113 const auto &Spec = LoggedFeatureSpecs[FeatureID]; 114 if (Spec.isElementType<float>()) { 115 auto *RF = FeatureLists[FeatureID] 116 .add_feature() 117 ->mutable_float_list() 118 ->mutable_value(); 119 RF->Resize(Spec.getElementCount(), 0.0); 120 return reinterpret_cast<char *>(RF->mutable_data()); 121 } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) { 122 auto *RF = FeatureLists[FeatureID] 123 .add_feature() 124 ->mutable_int64_list() 125 ->mutable_value(); 126 RF->Resize(Spec.getElementCount(), 0); 127 return reinterpret_cast<char *>(RF->mutable_data()); 128 } 129 llvm_unreachable("Unsupported tensor type."); 130 } 131 132 template <typename T> void logReward(T Value) { 133 assert(IncludeReward); 134 if (RewardSpec.isElementType<float>()) 135 Reward.add_feature()->mutable_float_list()->add_value(Value); 136 else if (RewardSpec.isElementType<int32_t>() || 137 RewardSpec.isElementType<int64_t>()) 138 Reward.add_feature()->mutable_int64_list()->add_value(Value); 139 else 140 llvm_unreachable("Unsupported tensor type."); 141 } 142 143 size_t getNrRecords() const { 144 return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size(); 145 } 146 }; 147 } // namespace llvm 148 149 Logger::Logger(const std::vector<TensorSpec> &FeatureSpecs, 150 const TensorSpec &RewardSpec, bool IncludeReward) 151 : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), 152 IncludeReward(IncludeReward), 153 LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec, 154 IncludeReward)) {} 155 156 Logger::~Logger() {} 157 158 #define LOG_REWARD(NAME, TYPE) \ 159 void Logger::log##NAME##Reward(TYPE Value) { \ 160 assert(IncludeReward); \ 161 LoggerData->logReward(Value); \ 162 } 163 164 LOG_REWARD(Float, float) 165 LOG_REWARD(Int32, int32_t) 166 LOG_REWARD(Int64, int64_t) 167 #undef LOG_REWARD 168 169 #define LOG_FINAL_REWARD(NAME, TYPE) \ 170 void Logger::log##NAME##FinalReward(TYPE Value) { \ 171 assert(RewardSpec.isElementType<TYPE>()); \ 172 for (size_t I = 1; I < LoggerData->getNrRecords(); ++I) \ 173 log##NAME##Reward(0); \ 174 log##NAME##Reward(Value); \ 175 } 176 177 LOG_FINAL_REWARD(Float, float) 178 LOG_FINAL_REWARD(Int32, int32_t) 179 LOG_FINAL_REWARD(Int64, int64_t) 180 #undef LOG_FINAL_REWARD 181 182 void Logger::logFloatValue(size_t FeatureID, const float *Value) { 183 assert(FeatureSpecs[FeatureID].isElementType<float>()); 184 logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value)); 185 } 186 187 void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) { 188 assert(FeatureSpecs[FeatureID].isElementType<int64_t>()); 189 logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value)); 190 } 191 192 void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) { 193 assert(FeatureSpecs[FeatureID].isElementType<int32_t>()); 194 logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value)); 195 } 196 197 void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) { 198 const auto &Spec = FeatureSpecs[FeatureID]; 199 char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID); 200 if (Spec.isElementType<int32_t>()) 201 for (size_t I = 0; I < Spec.getElementCount(); ++I) 202 (reinterpret_cast<int64_t *>(Buff))[I] = 203 static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]); 204 else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>()) 205 std::memcpy(Buff, RawData, 206 Spec.getElementCount() * Spec.getElementByteSize()); 207 else 208 llvm_unreachable("Unsupported tensor type"); 209 } 210 211 char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) { 212 return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID)); 213 } 214 215 void Logger::flush(std::string *Str) { LoggerData->flush(Str); } 216 217 void Logger::flush(raw_ostream &OS) { 218 std::string Buff; 219 LoggerData->flush(&Buff); 220 OS << Buff; 221 } 222 223 void Logger::flushLogs(raw_ostream &OS, 224 const StringMap<std::unique_ptr<Logger>> &Loggers) { 225 google::protobuf::Struct Msg; 226 for (const auto &NamedLogger : Loggers) { 227 tensorflow::SequenceExample SE; 228 const auto &Logger = NamedLogger.second; 229 std::string Unencoded; 230 if (Logger->LoggerData->getNrRecords() > 0) 231 Logger->flush(&Unencoded); 232 233 (*Msg.mutable_fields())[NamedLogger.first().str()] 234 .mutable_string_value() 235 ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded)); 236 } 237 238 std::string OutStr; 239 serialize(Msg, &OutStr); 240 OS << OutStr; 241 } 242 #endif // defined(LLVM_HAVE_TF_API) 243