//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logging infrastructure for extracting features and
// rewards for mlgo policy training.
//
//===----------------------------------------------------------------------===//
13bdd1243dSDimitry Andric #include "llvm/Analysis/TensorSpec.h"
14bdd1243dSDimitry Andric #include "llvm/Config/config.h"
15bdd1243dSDimitry Andric
16bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
17bdd1243dSDimitry Andric #include "llvm/Analysis/Utils/TrainingLogger.h"
18bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
19bdd1243dSDimitry Andric #include "llvm/Support/Debug.h"
20bdd1243dSDimitry Andric #include "llvm/Support/JSON.h"
21bdd1243dSDimitry Andric #include "llvm/Support/MemoryBuffer.h"
22bdd1243dSDimitry Andric #include "llvm/Support/Path.h"
23bdd1243dSDimitry Andric #include "llvm/Support/raw_ostream.h"
24bdd1243dSDimitry Andric
25bdd1243dSDimitry Andric #include <cassert>
26bdd1243dSDimitry Andric #include <numeric>
27bdd1243dSDimitry Andric
28bdd1243dSDimitry Andric using namespace llvm;
29bdd1243dSDimitry Andric
writeHeader(std::optional<TensorSpec> AdviceSpec)30*06c3fb27SDimitry Andric void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
31bdd1243dSDimitry Andric json::OStream JOS(*OS);
32bdd1243dSDimitry Andric JOS.object([&]() {
33bdd1243dSDimitry Andric JOS.attributeArray("features", [&]() {
34bdd1243dSDimitry Andric for (const auto &TS : FeatureSpecs)
35bdd1243dSDimitry Andric TS.toJSON(JOS);
36bdd1243dSDimitry Andric });
37bdd1243dSDimitry Andric if (IncludeReward) {
38bdd1243dSDimitry Andric JOS.attributeBegin("score");
39bdd1243dSDimitry Andric RewardSpec.toJSON(JOS);
40bdd1243dSDimitry Andric JOS.attributeEnd();
41bdd1243dSDimitry Andric }
42*06c3fb27SDimitry Andric if (AdviceSpec.has_value()) {
43*06c3fb27SDimitry Andric JOS.attributeBegin("advice");
44*06c3fb27SDimitry Andric AdviceSpec->toJSON(JOS);
45*06c3fb27SDimitry Andric JOS.attributeEnd();
46*06c3fb27SDimitry Andric }
47bdd1243dSDimitry Andric });
48bdd1243dSDimitry Andric *OS << "\n";
49bdd1243dSDimitry Andric }
50bdd1243dSDimitry Andric
switchContext(StringRef Name)51bdd1243dSDimitry Andric void Logger::switchContext(StringRef Name) {
52bdd1243dSDimitry Andric CurrentContext = Name.str();
53bdd1243dSDimitry Andric json::OStream JOS(*OS);
54bdd1243dSDimitry Andric JOS.object([&]() { JOS.attribute("context", Name); });
55bdd1243dSDimitry Andric *OS << "\n";
56bdd1243dSDimitry Andric }
57bdd1243dSDimitry Andric
startObservation()58bdd1243dSDimitry Andric void Logger::startObservation() {
59bdd1243dSDimitry Andric auto I = ObservationIDs.insert({CurrentContext, 0});
60bdd1243dSDimitry Andric size_t NewObservationID = I.second ? 0 : ++I.first->second;
61bdd1243dSDimitry Andric json::OStream JOS(*OS);
62bdd1243dSDimitry Andric JOS.object([&]() {
63bdd1243dSDimitry Andric JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
64bdd1243dSDimitry Andric });
65bdd1243dSDimitry Andric *OS << "\n";
66bdd1243dSDimitry Andric }
67bdd1243dSDimitry Andric
endObservation()68bdd1243dSDimitry Andric void Logger::endObservation() { *OS << "\n"; }
69bdd1243dSDimitry Andric
logRewardImpl(const char * RawData)70bdd1243dSDimitry Andric void Logger::logRewardImpl(const char *RawData) {
71bdd1243dSDimitry Andric assert(IncludeReward);
72bdd1243dSDimitry Andric json::OStream JOS(*OS);
73bdd1243dSDimitry Andric JOS.object([&]() {
74bdd1243dSDimitry Andric JOS.attribute("outcome", static_cast<int64_t>(
75bdd1243dSDimitry Andric ObservationIDs.find(CurrentContext)->second));
76bdd1243dSDimitry Andric });
77bdd1243dSDimitry Andric *OS << "\n";
78bdd1243dSDimitry Andric writeTensor(RewardSpec, RawData);
79bdd1243dSDimitry Andric *OS << "\n";
80bdd1243dSDimitry Andric }
81bdd1243dSDimitry Andric
Logger(std::unique_ptr<raw_ostream> OS,const std::vector<TensorSpec> & FeatureSpecs,const TensorSpec & RewardSpec,bool IncludeReward,std::optional<TensorSpec> AdviceSpec)82bdd1243dSDimitry Andric Logger::Logger(std::unique_ptr<raw_ostream> OS,
83bdd1243dSDimitry Andric const std::vector<TensorSpec> &FeatureSpecs,
84*06c3fb27SDimitry Andric const TensorSpec &RewardSpec, bool IncludeReward,
85*06c3fb27SDimitry Andric std::optional<TensorSpec> AdviceSpec)
86bdd1243dSDimitry Andric : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
87bdd1243dSDimitry Andric IncludeReward(IncludeReward) {
88*06c3fb27SDimitry Andric writeHeader(AdviceSpec);
89bdd1243dSDimitry Andric }
90