xref: /llvm-project/llvm/lib/Analysis/TrainingLogger.cpp (revision cdfb51295d814a875925974364931ef4337641e1)
10cb9746aSMircea Trofin //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
20cb9746aSMircea Trofin //
30cb9746aSMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40cb9746aSMircea Trofin // See https://llvm.org/LICENSE.txt for license information.
50cb9746aSMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60cb9746aSMircea Trofin //
70cb9746aSMircea Trofin //===----------------------------------------------------------------------===//
80cb9746aSMircea Trofin //
90cb9746aSMircea Trofin // This file implements logging infrastructure for extracting features and
100cb9746aSMircea Trofin // rewards for mlgo policy training.
110cb9746aSMircea Trofin //
120cb9746aSMircea Trofin //===----------------------------------------------------------------------===//
134c97745bSMircea Trofin #include "llvm/Analysis/TensorSpec.h"
140cb9746aSMircea Trofin #include "llvm/Config/config.h"
150cb9746aSMircea Trofin 
160cb9746aSMircea Trofin #include "llvm/ADT/Twine.h"
170cb9746aSMircea Trofin #include "llvm/Analysis/Utils/TrainingLogger.h"
180cb9746aSMircea Trofin #include "llvm/Support/CommandLine.h"
190cb9746aSMircea Trofin #include "llvm/Support/Debug.h"
200cb9746aSMircea Trofin #include "llvm/Support/JSON.h"
210cb9746aSMircea Trofin #include "llvm/Support/MemoryBuffer.h"
220cb9746aSMircea Trofin #include "llvm/Support/Path.h"
230cb9746aSMircea Trofin #include "llvm/Support/raw_ostream.h"
240cb9746aSMircea Trofin 
250cb9746aSMircea Trofin #include <cassert>
260cb9746aSMircea Trofin #include <numeric>
270cb9746aSMircea Trofin 
280cb9746aSMircea Trofin using namespace llvm;
290cb9746aSMircea Trofin 
writeHeader(std::optional<TensorSpec> AdviceSpec)30*35aa7374SMircea Trofin void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
316d11baf0SMircea Trofin   json::OStream JOS(*OS);
324c97745bSMircea Trofin   JOS.object([&]() {
334c97745bSMircea Trofin     JOS.attributeArray("features", [&]() {
349bd69ae8SMircea Trofin       for (const auto &TS : FeatureSpecs)
354c97745bSMircea Trofin         TS.toJSON(JOS);
364c97745bSMircea Trofin     });
374c97745bSMircea Trofin     if (IncludeReward) {
384c97745bSMircea Trofin       JOS.attributeBegin("score");
394c97745bSMircea Trofin       RewardSpec.toJSON(JOS);
404c97745bSMircea Trofin       JOS.attributeEnd();
414c97745bSMircea Trofin     }
42*35aa7374SMircea Trofin     if (AdviceSpec.has_value()) {
43*35aa7374SMircea Trofin       JOS.attributeBegin("advice");
44*35aa7374SMircea Trofin       AdviceSpec->toJSON(JOS);
45*35aa7374SMircea Trofin       JOS.attributeEnd();
46*35aa7374SMircea Trofin     }
474c97745bSMircea Trofin   });
486d11baf0SMircea Trofin   *OS << "\n";
494c97745bSMircea Trofin }
504c97745bSMircea Trofin 
switchContext(StringRef Name)516d11baf0SMircea Trofin void Logger::switchContext(StringRef Name) {
526d11baf0SMircea Trofin   CurrentContext = Name.str();
536d11baf0SMircea Trofin   json::OStream JOS(*OS);
544c97745bSMircea Trofin   JOS.object([&]() { JOS.attribute("context", Name); });
556d11baf0SMircea Trofin   *OS << "\n";
564c97745bSMircea Trofin }
574c97745bSMircea Trofin 
startObservation()586d11baf0SMircea Trofin void Logger::startObservation() {
596d11baf0SMircea Trofin   auto I = ObservationIDs.insert({CurrentContext, 0});
606d11baf0SMircea Trofin   size_t NewObservationID = I.second ? 0 : ++I.first->second;
616d11baf0SMircea Trofin   json::OStream JOS(*OS);
62d581308dSMircea Trofin   JOS.object([&]() {
636d11baf0SMircea Trofin     JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
64d581308dSMircea Trofin   });
656d11baf0SMircea Trofin   *OS << "\n";
664c97745bSMircea Trofin }
674c97745bSMircea Trofin 
endObservation()686d11baf0SMircea Trofin void Logger::endObservation() { *OS << "\n"; }
696d11baf0SMircea Trofin 
logRewardImpl(const char * RawData)706d11baf0SMircea Trofin void Logger::logRewardImpl(const char *RawData) {
716d11baf0SMircea Trofin   assert(IncludeReward);
726d11baf0SMircea Trofin   json::OStream JOS(*OS);
736d11baf0SMircea Trofin   JOS.object([&]() {
746d11baf0SMircea Trofin     JOS.attribute("outcome", static_cast<int64_t>(
756d11baf0SMircea Trofin                                  ObservationIDs.find(CurrentContext)->second));
766d11baf0SMircea Trofin   });
776d11baf0SMircea Trofin   *OS << "\n";
786d11baf0SMircea Trofin   writeTensor(RewardSpec, RawData);
796d11baf0SMircea Trofin   *OS << "\n";
804c97745bSMircea Trofin }
814c97745bSMircea Trofin 
Logger(std::unique_ptr<raw_ostream> OS,const std::vector<TensorSpec> & FeatureSpecs,const TensorSpec & RewardSpec,bool IncludeReward,std::optional<TensorSpec> AdviceSpec)826d11baf0SMircea Trofin Logger::Logger(std::unique_ptr<raw_ostream> OS,
836d11baf0SMircea Trofin                const std::vector<TensorSpec> &FeatureSpecs,
84*35aa7374SMircea Trofin                const TensorSpec &RewardSpec, bool IncludeReward,
85*35aa7374SMircea Trofin                std::optional<TensorSpec> AdviceSpec)
866d11baf0SMircea Trofin     : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
876d11baf0SMircea Trofin       IncludeReward(IncludeReward) {
88*35aa7374SMircea Trofin   writeHeader(AdviceSpec);
890cb9746aSMircea Trofin }
90