//===- InteractiveModelRunner.cpp - interactive ML model runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A runner that communicates with an external agent via two file descriptors.
//===----------------------------------------------------------------------===//
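//
// Shape of the exchange, as implemented below (a sketch, not a formal wire
// spec): the constructor opens the inbound channel for reading and the
// outbound channel for writing, and logs a header describing the tensor
// specs. Each evaluateUntyped() call then writes one observation (all input
// tensors) to the outbound channel and blocks until exactly
// OutputSpec.getTotalTensorBufferSize() bytes of advice arrive back on the
// inbound channel.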
#include "llvm/Analysis/InteractiveModelRunner.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define _IMR_CL_VALS(T, N) clEnumValN(TensorType::N, #T, #T),

static cl::opt<TensorType> DebugReply(
    "interactive-model-runner-echo-reply", cl::init(TensorType::Invalid),
    cl::Hidden,
    cl::desc("The InteractiveModelRunner will echo back to stderr the data "
             "received from the host, interpreted as the specified type (for "
             "debugging purposes)."),
    cl::values(SUPPORTED_TENSOR_TYPES(_IMR_CL_VALS)
                   clEnumValN(TensorType::Invalid, "disable", "Don't echo")));

#undef _IMR_CL_VALS
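
// Example (hypothetical invocation; the legal values are the stringified C
// types from SUPPORTED_TENSOR_TYPES, e.g. "int64_t", plus "disable"):
//   ... -interactive-model-runner-echo-reply=int64_t ...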

InteractiveModelRunner::InteractiveModelRunner(
    LLVMContext &Ctx, const std::vector<TensorSpec> &Inputs,
    const TensorSpec &Advice, StringRef OutboundName, StringRef InboundName)
    : MLModelRunner(Ctx, MLModelRunner::Kind::Interactive, Inputs.size()),
      InputSpecs(Inputs), OutputSpec(Advice), Inbound(InboundName, InEC),
      OutputBuffer(OutputSpec.getTotalTensorBufferSize()),
      Log(std::make_unique<raw_fd_ostream>(OutboundName, OutEC), InputSpecs,
          Advice, /*IncludeReward=*/false) {
  if (InEC) {
    Ctx.emitError("Cannot open inbound file: " + InEC.message());
    return;
  }
  if (OutEC) {
    Ctx.emitError("Cannot open outbound file: " + OutEC.message());
    return;
  }
  // Just like in the no-inference case, this will allocate an appropriately
  // sized buffer.
  for (size_t I = 0; I < InputSpecs.size(); ++I)
    setUpBufferForTensor(I, InputSpecs[I], nullptr);
  // The Logger emitted its header (the tensor specs) at construction; flush
  // now so the agent sees it before the first observation.
  Log.flush();
}
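
// A minimal construction sketch (the paths and the advice spec are made up;
// TensorSpec::createSpec is the helper declared in TensorSpec.h):
//   InteractiveModelRunner Runner(
//       Ctx, Inputs, TensorSpec::createSpec<int64_t>("advice", {1}),
//       /*OutboundName=*/"/tmp/outbound.pipe",
//       /*InboundName=*/"/tmp/inbound.pipe");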

void *InteractiveModelRunner::evaluateUntyped() {
  Log.startObservation();
  for (size_t I = 0; I < InputSpecs.size(); ++I)
    Log.logTensorValue(I, reinterpret_cast<const char *>(getTensorUntyped(I)));
  Log.endObservation();
  Log.flush();

  // Block until the full advice tensor has been read back from the agent.
  size_t InsPoint = 0;
  char *Buff = OutputBuffer.data();
  const size_t Limit = OutputBuffer.size();
  while (InsPoint < Limit) {
    auto Read = Inbound.read(Buff + InsPoint, Limit - InsPoint);
    if (Read < 0) {
      Ctx.emitError("Failed reading from inbound file");
      break;
    }
    // A zero-length read means the agent closed its end of the channel; bail
    // out instead of spinning forever waiting for the rest of the advice.
    if (Read == 0) {
      Ctx.emitError("Inbound file closed before the full advice was read");
      break;
    }
    InsPoint += Read;
  }
  if (DebugReply != TensorType::Invalid)
    dbgs() << tensorValueToString(OutputBuffer.data(), OutputSpec);
  return OutputBuffer.data();
}
77 }