xref: /llvm-project/llvm/unittests/Analysis/MLModelRunnerTest.cpp (revision 89e6a288674c9fae33aeb5448c7b1fe782b2bf53)
1059e0347SMircea Trofin //===- MLModelRunnerTest.cpp - test for MLModelRunner ---------------------===//
2059e0347SMircea Trofin //
3059e0347SMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4059e0347SMircea Trofin // See https://llvm.org/LICENSE.txt for license information.
5059e0347SMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6059e0347SMircea Trofin //
7059e0347SMircea Trofin //===----------------------------------------------------------------------===//
8059e0347SMircea Trofin 
9059e0347SMircea Trofin #include "llvm/Analysis/MLModelRunner.h"
10313b1a82SMircea Trofin #include "llvm/ADT/StringExtras.h"
115b8dc7c8SMircea Trofin #include "llvm/Analysis/InteractiveModelRunner.h"
12059e0347SMircea Trofin #include "llvm/Analysis/NoInferenceModelRunner.h"
13c35ad9eeSMircea Trofin #include "llvm/Analysis/ReleaseModeModelRunner.h"
14*89e6a288SDaniil Fukalov #include "llvm/Config/llvm-config.h" // for LLVM_ON_UNIX
155b8dc7c8SMircea Trofin #include "llvm/Support/BinaryByteStream.h"
16313b1a82SMircea Trofin #include "llvm/Support/ErrorHandling.h"
1783051c5aSMircea Trofin #include "llvm/Support/FileSystem.h"
185b8dc7c8SMircea Trofin #include "llvm/Support/FileUtilities.h"
195b8dc7c8SMircea Trofin #include "llvm/Support/JSON.h"
2083051c5aSMircea Trofin #include "llvm/Support/Path.h"
215b8dc7c8SMircea Trofin #include "llvm/Support/raw_ostream.h"
2283051c5aSMircea Trofin #include "llvm/Testing/Support/SupportHelpers.h"
23059e0347SMircea Trofin #include "gtest/gtest.h"
245b8dc7c8SMircea Trofin #include <atomic>
255b8dc7c8SMircea Trofin #include <thread>
265b8dc7c8SMircea Trofin 
27059e0347SMircea Trofin using namespace llvm;
28059e0347SMircea Trofin 
29c35ad9eeSMircea Trofin namespace llvm {
30c35ad9eeSMircea Trofin // This is a mock of the kind of AOT-generated model evaluator. It has 2 tensors
31c35ad9eeSMircea Trofin // of shape {1}, and 'evaluation' adds them.
32c35ad9eeSMircea Trofin // The interface is the one expected by ReleaseModelRunner.
33313b1a82SMircea Trofin class MockAOTModelBase {
34313b1a82SMircea Trofin protected:
35c35ad9eeSMircea Trofin   int64_t A = 0;
36c35ad9eeSMircea Trofin   int64_t B = 0;
37c35ad9eeSMircea Trofin   int64_t R = 0;
38c35ad9eeSMircea Trofin 
39c35ad9eeSMircea Trofin public:
40313b1a82SMircea Trofin   MockAOTModelBase() = default;
41313b1a82SMircea Trofin   virtual ~MockAOTModelBase() = default;
42313b1a82SMircea Trofin 
43313b1a82SMircea Trofin   virtual int LookupArgIndex(const std::string &Name) {
44c35ad9eeSMircea Trofin     if (Name == "prefix_a")
45c35ad9eeSMircea Trofin       return 0;
46c35ad9eeSMircea Trofin     if (Name == "prefix_b")
47c35ad9eeSMircea Trofin       return 1;
48c35ad9eeSMircea Trofin     return -1;
49c35ad9eeSMircea Trofin   }
50c35ad9eeSMircea Trofin   int LookupResultIndex(const std::string &) { return 0; }
51313b1a82SMircea Trofin   virtual void Run() = 0;
52313b1a82SMircea Trofin   virtual void *result_data(int RIndex) {
53c35ad9eeSMircea Trofin     if (RIndex == 0)
54c35ad9eeSMircea Trofin       return &R;
55c35ad9eeSMircea Trofin     return nullptr;
56c35ad9eeSMircea Trofin   }
57313b1a82SMircea Trofin   virtual void *arg_data(int Index) {
58c35ad9eeSMircea Trofin     switch (Index) {
59c35ad9eeSMircea Trofin     case 0:
60c35ad9eeSMircea Trofin       return &A;
61c35ad9eeSMircea Trofin     case 1:
62c35ad9eeSMircea Trofin       return &B;
63c35ad9eeSMircea Trofin     default:
64c35ad9eeSMircea Trofin       return nullptr;
65c35ad9eeSMircea Trofin     }
66c35ad9eeSMircea Trofin   }
67c35ad9eeSMircea Trofin };
68313b1a82SMircea Trofin 
69313b1a82SMircea Trofin class AdditionAOTModel final : public MockAOTModelBase {
70313b1a82SMircea Trofin public:
71313b1a82SMircea Trofin   AdditionAOTModel() = default;
72313b1a82SMircea Trofin   void Run() override { R = A + B; }
73313b1a82SMircea Trofin };
74313b1a82SMircea Trofin 
75313b1a82SMircea Trofin class DiffAOTModel final : public MockAOTModelBase {
76313b1a82SMircea Trofin public:
77313b1a82SMircea Trofin   DiffAOTModel() = default;
78313b1a82SMircea Trofin   void Run() override { R = A - B; }
79313b1a82SMircea Trofin };
80313b1a82SMircea Trofin 
81313b1a82SMircea Trofin static const char *M1Selector = "the model that subtracts";
82313b1a82SMircea Trofin static const char *M2Selector = "the model that adds";
83313b1a82SMircea Trofin 
84313b1a82SMircea Trofin static MD5::MD5Result Hash1 = MD5::hash(arrayRefFromStringRef(M1Selector));
85313b1a82SMircea Trofin static MD5::MD5Result Hash2 = MD5::hash(arrayRefFromStringRef(M2Selector));
86313b1a82SMircea Trofin class ComposedAOTModel final {
87313b1a82SMircea Trofin   DiffAOTModel M1;
88313b1a82SMircea Trofin   AdditionAOTModel M2;
89313b1a82SMircea Trofin   uint64_t Selector[2] = {0};
90313b1a82SMircea Trofin 
91313b1a82SMircea Trofin   bool isHashSameAsSelector(const std::pair<uint64_t, uint64_t> &Words) const {
92313b1a82SMircea Trofin     return Selector[0] == Words.first && Selector[1] == Words.second;
93313b1a82SMircea Trofin   }
94313b1a82SMircea Trofin   MockAOTModelBase *getModel() {
95313b1a82SMircea Trofin     if (isHashSameAsSelector(Hash1.words()))
96313b1a82SMircea Trofin       return &M1;
97313b1a82SMircea Trofin     if (isHashSameAsSelector(Hash2.words()))
98313b1a82SMircea Trofin       return &M2;
99313b1a82SMircea Trofin     llvm_unreachable("Should be one of the two");
100313b1a82SMircea Trofin   }
101313b1a82SMircea Trofin 
102313b1a82SMircea Trofin public:
103313b1a82SMircea Trofin   ComposedAOTModel() = default;
104313b1a82SMircea Trofin   int LookupArgIndex(const std::string &Name) {
1053a462d89SMircea Trofin     if (Name == "prefix_model_selector")
106313b1a82SMircea Trofin       return 2;
107313b1a82SMircea Trofin     return getModel()->LookupArgIndex(Name);
108313b1a82SMircea Trofin   }
109313b1a82SMircea Trofin   int LookupResultIndex(const std::string &Name) {
110313b1a82SMircea Trofin     return getModel()->LookupResultIndex(Name);
111313b1a82SMircea Trofin   }
112313b1a82SMircea Trofin   void *arg_data(int Index) {
113313b1a82SMircea Trofin     if (Index == 2)
114313b1a82SMircea Trofin       return Selector;
115313b1a82SMircea Trofin     return getModel()->arg_data(Index);
116313b1a82SMircea Trofin   }
117313b1a82SMircea Trofin   void *result_data(int RIndex) { return getModel()->result_data(RIndex); }
118313b1a82SMircea Trofin   void Run() { getModel()->Run(); }
119313b1a82SMircea Trofin };
120313b1a82SMircea Trofin 
121313b1a82SMircea Trofin static EmbeddedModelRunnerOptions makeOptions() {
122313b1a82SMircea Trofin   EmbeddedModelRunnerOptions Opts;
123313b1a82SMircea Trofin   Opts.setFeedPrefix("prefix_");
124313b1a82SMircea Trofin   return Opts;
125313b1a82SMircea Trofin }
126c35ad9eeSMircea Trofin } // namespace llvm
127c35ad9eeSMircea Trofin 
128059e0347SMircea Trofin TEST(NoInferenceModelRunner, AccessTensors) {
129059e0347SMircea Trofin   const std::vector<TensorSpec> Inputs{
130059e0347SMircea Trofin       TensorSpec::createSpec<int64_t>("F1", {1}),
131059e0347SMircea Trofin       TensorSpec::createSpec<int64_t>("F2", {10}),
132059e0347SMircea Trofin       TensorSpec::createSpec<float>("F2", {5}),
133059e0347SMircea Trofin   };
134059e0347SMircea Trofin   LLVMContext Ctx;
135059e0347SMircea Trofin   NoInferenceModelRunner NIMR(Ctx, Inputs);
136059e0347SMircea Trofin   NIMR.getTensor<int64_t>(0)[0] = 1;
137059e0347SMircea Trofin   std::memcpy(NIMR.getTensor<int64_t>(1),
138059e0347SMircea Trofin               std::vector<int64_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}.data(),
139059e0347SMircea Trofin               10 * sizeof(int64_t));
140059e0347SMircea Trofin   std::memcpy(NIMR.getTensor<float>(2),
141345ed58eSSimon Pilgrim               std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}.data(),
142059e0347SMircea Trofin               5 * sizeof(float));
143059e0347SMircea Trofin   ASSERT_EQ(NIMR.getTensor<int64_t>(0)[0], 1);
144059e0347SMircea Trofin   ASSERT_EQ(NIMR.getTensor<int64_t>(1)[8], 9);
145059e0347SMircea Trofin   ASSERT_EQ(NIMR.getTensor<float>(2)[1], 0.2f);
146059e0347SMircea Trofin }
147c35ad9eeSMircea Trofin 
148c35ad9eeSMircea Trofin TEST(ReleaseModeRunner, NormalUse) {
149c35ad9eeSMircea Trofin   LLVMContext Ctx;
150c35ad9eeSMircea Trofin   std::vector<TensorSpec> Inputs{TensorSpec::createSpec<int64_t>("a", {1}),
151c35ad9eeSMircea Trofin                                  TensorSpec::createSpec<int64_t>("b", {1})};
152313b1a82SMircea Trofin   auto Evaluator = std::make_unique<ReleaseModeModelRunner<AdditionAOTModel>>(
153313b1a82SMircea Trofin       Ctx, Inputs, "", makeOptions());
154c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(0) = 1;
155c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(1) = 2;
156c35ad9eeSMircea Trofin   EXPECT_EQ(Evaluator->evaluate<int64_t>(), 3);
157c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(0), 1);
158c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(1), 2);
159c35ad9eeSMircea Trofin }
160c35ad9eeSMircea Trofin 
161c35ad9eeSMircea Trofin TEST(ReleaseModeRunner, ExtraFeatures) {
162c35ad9eeSMircea Trofin   LLVMContext Ctx;
163c35ad9eeSMircea Trofin   std::vector<TensorSpec> Inputs{TensorSpec::createSpec<int64_t>("a", {1}),
164c35ad9eeSMircea Trofin                                  TensorSpec::createSpec<int64_t>("b", {1}),
165c35ad9eeSMircea Trofin                                  TensorSpec::createSpec<int64_t>("c", {1})};
166313b1a82SMircea Trofin   auto Evaluator = std::make_unique<ReleaseModeModelRunner<AdditionAOTModel>>(
167313b1a82SMircea Trofin       Ctx, Inputs, "", makeOptions());
168c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(0) = 1;
169c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(1) = 2;
170c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(2) = -3;
171c35ad9eeSMircea Trofin   EXPECT_EQ(Evaluator->evaluate<int64_t>(), 3);
172c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(0), 1);
173c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(1), 2);
174c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(2), -3);
175c35ad9eeSMircea Trofin }
176c35ad9eeSMircea Trofin 
177c35ad9eeSMircea Trofin TEST(ReleaseModeRunner, ExtraFeaturesOutOfOrder) {
178c35ad9eeSMircea Trofin   LLVMContext Ctx;
179c35ad9eeSMircea Trofin   std::vector<TensorSpec> Inputs{
180c35ad9eeSMircea Trofin       TensorSpec::createSpec<int64_t>("a", {1}),
181c35ad9eeSMircea Trofin       TensorSpec::createSpec<int64_t>("c", {1}),
182c35ad9eeSMircea Trofin       TensorSpec::createSpec<int64_t>("b", {1}),
183c35ad9eeSMircea Trofin   };
184313b1a82SMircea Trofin   auto Evaluator = std::make_unique<ReleaseModeModelRunner<AdditionAOTModel>>(
185313b1a82SMircea Trofin       Ctx, Inputs, "", makeOptions());
186c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(0) = 1;         // a
187c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(1) = 2;         // c
188c35ad9eeSMircea Trofin   *Evaluator->getTensor<int64_t>(2) = -3;        // b
189c35ad9eeSMircea Trofin   EXPECT_EQ(Evaluator->evaluate<int64_t>(), -2); // a + b
190c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(0), 1);
191c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(1), 2);
192c35ad9eeSMircea Trofin   EXPECT_EQ(*Evaluator->getTensor<int64_t>(2), -3);
193c35ad9eeSMircea Trofin }
1945b8dc7c8SMircea Trofin 
195313b1a82SMircea Trofin // We expect an error to be reported early if the user tried to specify a model
196313b1a82SMircea Trofin // selector, but the model in fact doesn't support that.
197313b1a82SMircea Trofin TEST(ReleaseModelRunner, ModelSelectorNoInputFeaturePresent) {
198313b1a82SMircea Trofin   LLVMContext Ctx;
199313b1a82SMircea Trofin   std::vector<TensorSpec> Inputs{TensorSpec::createSpec<int64_t>("a", {1}),
200313b1a82SMircea Trofin                                  TensorSpec::createSpec<int64_t>("b", {1})};
20113be6ee7SSimon Pilgrim   EXPECT_DEATH((void)std::make_unique<ReleaseModeModelRunner<AdditionAOTModel>>(
202313b1a82SMircea Trofin                    Ctx, Inputs, "", makeOptions().setModelSelector(M2Selector)),
203313b1a82SMircea Trofin                "A model selector was specified but the underlying model does "
2043a462d89SMircea Trofin                "not expose a model_selector input");
205313b1a82SMircea Trofin }
206313b1a82SMircea Trofin 
207313b1a82SMircea Trofin TEST(ReleaseModelRunner, ModelSelectorNoSelectorGiven) {
208313b1a82SMircea Trofin   LLVMContext Ctx;
209313b1a82SMircea Trofin   std::vector<TensorSpec> Inputs{TensorSpec::createSpec<int64_t>("a", {1}),
210313b1a82SMircea Trofin                                  TensorSpec::createSpec<int64_t>("b", {1})};
211313b1a82SMircea Trofin   EXPECT_DEATH(
21213be6ee7SSimon Pilgrim       (void)std::make_unique<ReleaseModeModelRunner<ComposedAOTModel>>(
213313b1a82SMircea Trofin           Ctx, Inputs, "", makeOptions()),
214313b1a82SMircea Trofin       "A model selector was not specified but the underlying model requires "
2153a462d89SMircea Trofin       "selecting one because it exposes a model_selector input");
216313b1a82SMircea Trofin }
217313b1a82SMircea Trofin 
// Test that we correctly set up the model_selector tensor value. We are only
// responsible for two cases when the model supports the feature: the user does
// not specify a value, or the user specifies one. In the latter case we must
// populate the tensor correctly, and do so upfront, in case the model
// implementation needs the selector for subsequent tensor buffer lookups.
223313b1a82SMircea Trofin TEST(ReleaseModelRunner, ModelSelector) {
224313b1a82SMircea Trofin   LLVMContext Ctx;
225313b1a82SMircea Trofin   std::vector<TensorSpec> Inputs{TensorSpec::createSpec<int64_t>("a", {1}),
226313b1a82SMircea Trofin                                  TensorSpec::createSpec<int64_t>("b", {1})};
227313b1a82SMircea Trofin   // This explicitly asks for M1
228313b1a82SMircea Trofin   auto Evaluator = std::make_unique<ReleaseModeModelRunner<ComposedAOTModel>>(
229313b1a82SMircea Trofin       Ctx, Inputs, "", makeOptions().setModelSelector(M1Selector));
230313b1a82SMircea Trofin   *Evaluator->getTensor<int64_t>(0) = 1;
231313b1a82SMircea Trofin   *Evaluator->getTensor<int64_t>(1) = 2;
232313b1a82SMircea Trofin   EXPECT_EQ(Evaluator->evaluate<int64_t>(), -1);
233313b1a82SMircea Trofin 
234313b1a82SMircea Trofin   // Ask for M2
235313b1a82SMircea Trofin   Evaluator = std::make_unique<ReleaseModeModelRunner<ComposedAOTModel>>(
236313b1a82SMircea Trofin       Ctx, Inputs, "", makeOptions().setModelSelector(M2Selector));
237313b1a82SMircea Trofin   *Evaluator->getTensor<int64_t>(0) = 1;
238313b1a82SMircea Trofin   *Evaluator->getTensor<int64_t>(1) = 2;
239313b1a82SMircea Trofin   EXPECT_EQ(Evaluator->evaluate<int64_t>(), 3);
240313b1a82SMircea Trofin 
241313b1a82SMircea Trofin   // Asking for a model that's not supported isn't handled by our infra and we
242313b1a82SMircea Trofin   // expect the model implementation to fail at a point.
243313b1a82SMircea Trofin }
244313b1a82SMircea Trofin 
#if defined(LLVM_ON_UNIX)
TEST(InteractiveModelRunner, Evaluation) {
  LLVMContext Ctx;
  // Test the interaction with an external advisor by asking for advice twice.
  // Use simple values, since we use the Logger underneath, that's tested more
  // extensively elsewhere.
  std::vector<TensorSpec> Inputs{
      TensorSpec::createSpec<int64_t>("a", {1}),
      TensorSpec::createSpec<int64_t>("b", {1}),
      TensorSpec::createSpec<int64_t>("c", {1}),
  };
  TensorSpec AdviceSpec = TensorSpec::createSpec<float>("advice", {1});

  // Create the two named pipes (FIFOs) the compiler and the advisor talk
  // through. mkfifo is POSIX-only, hence the LLVM_ON_UNIX guard. The rest of
  // the test depends on these being real FIFOs, so stop right away if either
  // cannot be created. This happens before the advisor thread exists, so
  // returning early is safe; TempDir removes the directory on exit.
  llvm::unittest::TempDir Tmp("tmpdir", /*Unique=*/true);
  SmallString<128> FromCompilerName(Tmp.path());
  SmallString<128> ToCompilerName(Tmp.path());
  sys::path::append(FromCompilerName, "InteractiveModelRunner_Evaluation.out");
  sys::path::append(ToCompilerName, "InteractiveModelRunner_Evaluation.in");
  ASSERT_EQ(::mkfifo(FromCompilerName.c_str(), 0666), 0);
  ASSERT_EQ(::mkfifo(ToCompilerName.c_str(), 0666), 0);

  FileRemover Cleanup1(FromCompilerName);
  FileRemover Cleanup2(ToCompilerName);

  // Since the evaluator sends the features over and then blocks waiting for
  // an answer, we must spawn a thread playing the role of the advisor / host:
  std::atomic<int> SeenObservations{0};
  // Start the host first to make sure the pipes are being prepared. Otherwise
  // the evaluator will hang.
  std::thread Advisor([&]() {
    // Open the writer first. This is because the evaluator will try opening
    // the "input" pipe first. An alternative that avoids ordering is for the
    // host to open the pipes RW.
    std::error_code EC;
    raw_fd_ostream ToCompiler(ToCompilerName, EC);
    EXPECT_FALSE(EC);
    int FromCompilerHandle = 0;
    EXPECT_FALSE(
        sys::fs::openFileForRead(FromCompilerName, FromCompilerHandle));
    sys::fs::file_t FromCompiler =
        sys::fs::convertFDToNativeFile(FromCompilerHandle);
    EXPECT_EQ(SeenObservations, 0);

    // Performs one read from the compiler pipe into Buf and stores the byte
    // count in BytesRead. On a read error, it records a test failure (which
    // also handles the Error) and returns false, so callers never dereference
    // a failed Expected.
    auto ReadSome = [&](MutableArrayRef<char> Buf, size_t &BytesRead) -> bool {
      Expected<size_t> ReadOrErr = sys::fs::readNativeFile(FromCompiler, Buf);
      if (!ReadOrErr) {
        ADD_FAILURE() << "reading from the compiler pipe failed: "
                      << toString(ReadOrErr.takeError());
        return false;
      }
      BytesRead = *ReadOrErr;
      return true;
    };

    // Helper to read headers and other json lines. Returns an empty line if a
    // read fails.
    SmallVector<char, 1024> Buffer;
    auto ReadLn = [&]() {
      Buffer.clear();
      while (true) {
        char Chr = 0;
        size_t BytesRead = 0;
        if (!ReadSome({&Chr, 1}, BytesRead))
          return StringRef();
        if (!BytesRead)
          continue;
        if (Chr == '\n')
          return StringRef(Buffer.data(), Buffer.size());
        Buffer.push_back(Chr);
      }
    };
    // See include/llvm/Analysis/Utils/TrainingLogger.h
    // First comes the header
    auto Header = json::parse(ReadLn());
    EXPECT_FALSE(Header.takeError());
    EXPECT_NE(Header->getAsObject()->getArray("features"), nullptr);
    EXPECT_NE(Header->getAsObject()->getObject("advice"), nullptr);
    // Then comes the context
    EXPECT_FALSE(json::parse(ReadLn()).takeError());

    // Reads the raw bytes of the three int64 feature tensors into Features,
    // looping because a single read may return fewer bytes than requested.
    int64_t Features[3] = {0};
    auto FullyRead = [&]() {
      size_t InsPt = 0;
      const size_t ToRead = 3 * Inputs[0].getTotalTensorBufferSize();
      char *Buff = reinterpret_cast<char *>(Features);
      while (InsPt < ToRead) {
        size_t BytesRead = 0;
        if (!ReadSome({Buff + InsPt, ToRead - InsPt}, BytesRead))
          return;
        InsPt += BytesRead;
      }
    };
    // Observation
    EXPECT_FALSE(json::parse(ReadLn()).takeError());
    // Tensor values
    FullyRead();
    // a "\n"
    char Chr = 0;
    // Reads exactly one byte into Chr, retrying on empty reads.
    auto ReadNL = [&]() {
      size_t BytesRead = 0;
      do {
        if (!ReadSome({&Chr, 1}, BytesRead))
          return;
      } while (BytesRead != 1);
    };
    ReadNL();
    EXPECT_EQ(Chr, '\n');
    EXPECT_EQ(Features[0], 42);
    EXPECT_EQ(Features[1], 43);
    EXPECT_EQ(Features[2], 100);
    ++SeenObservations;

    // Send the advice
    float Advice = 42.0012f;
    ToCompiler.write(reinterpret_cast<const char *>(&Advice),
                     AdviceSpec.getTotalTensorBufferSize());
    ToCompiler.flush();

    // Second observation, and same idea as above
    EXPECT_FALSE(json::parse(ReadLn()).takeError());
    FullyRead();
    ReadNL();
    EXPECT_EQ(Chr, '\n');
    EXPECT_EQ(Features[0], 10);
    EXPECT_EQ(Features[1], -2);
    EXPECT_EQ(Features[2], 1);
    ++SeenObservations;
    Advice = 50.30f;
    ToCompiler.write(reinterpret_cast<const char *>(&Advice),
                     AdviceSpec.getTotalTensorBufferSize());
    ToCompiler.flush();
    sys::fs::closeFile(FromCompiler);
  });

  InteractiveModelRunner Evaluator(Ctx, Inputs, AdviceSpec, FromCompilerName,
                                   ToCompilerName);

  Evaluator.switchContext("hi");

  EXPECT_EQ(SeenObservations, 0);
  *Evaluator.getTensor<int64_t>(0) = 42;
  *Evaluator.getTensor<int64_t>(1) = 43;
  *Evaluator.getTensor<int64_t>(2) = 100;
  float Ret = Evaluator.evaluate<float>();
  EXPECT_EQ(SeenObservations, 1);
  EXPECT_FLOAT_EQ(Ret, 42.0012f);

  *Evaluator.getTensor<int64_t>(0) = 10;
  *Evaluator.getTensor<int64_t>(1) = -2;
  *Evaluator.getTensor<int64_t>(2) = 1;
  Ret = Evaluator.evaluate<float>();
  EXPECT_EQ(SeenObservations, 2);
  EXPECT_FLOAT_EQ(Ret, 50.30f);
  Advisor.join();
}
#endif
390