xref: /llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision 1ee3bb17c39579de21ea0bd526e79bb932b8b1c3)
1 //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
10 // happens off a model that's provided from the command line and is interpreted.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Config/config.h"
16 #if defined(LLVM_HAVE_TF_API)
17 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/Path.h"
20 
21 using namespace llvm;
22 namespace {
// Pairs a tensor description with the (optional) name under which it should
// appear in the training log.
struct LoggedFeatureSpec {
  // The tensor this output spec describes.
  TensorSpec Spec;
  // Alternative name used when logging this tensor; callers fall back to
  // Spec.name() when absent.
  std::optional<std::string> LoggingName;
};
27 
// Parse the output specs json file — either the explicit SpecFileOverride
// path, or <ModelPath>/output_spec.json when no override is given — into a
// list of LoggedFeatureSpec. On any I/O, parse, or validation failure, a
// diagnostic is emitted on Ctx and None is returned. The first entry must
// describe the decision tensor and carry ExpectedDecisionName as its
// logging_name.
Optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    // No override: default to output_spec.json next to the model.
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return None;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return None;
  }
  // The top-level json value must be an array of
  // {tensor_spec: <object>, logging_name: <string>} dictionaries.
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return None;
  }
  std::vector<LoggedFeatureSpec> Ret;
  // Collect only fully-formed entries: an object with a parseable
  // 'tensor_spec' and a string 'logging_name'. Malformed entries are
  // silently skipped here; the size cross-check below turns any skip into a
  // hard error.
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            // Only these element types are supported for logged outputs.
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return None;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

  // If any array entry was skipped above, the counts disagree and the file
  // is considered malformed as a whole.
  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return None;
  }
  // The first output spec is the decision tensor; its logging name must
  // match the expected decision name.
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return None;
  }
  return Ret;
}
90 } // namespace
91 
92 ModelUnderTrainingRunner::ModelUnderTrainingRunner(
93     LLVMContext &Ctx, const std::string &ModelPath,
94     const std::vector<TensorSpec> &InputSpecs,
95     const std::vector<TensorSpec> &OutputSpecs,
96     const std::vector<TensorSpec> &ExtraOutputsForLogging)
97     : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
98       OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
99   Evaluator =
100       std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
101   if (!Evaluator || !Evaluator->isValid()) {
102     Ctx.emitError("Failed to create saved model evaluator");
103     Evaluator.reset();
104     return;
105   }
106 
107   for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
108     setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
109   }
110 }
111 
112 void *ModelUnderTrainingRunner::evaluateUntyped() {
113   LastEvaluationResult = Evaluator->evaluate();
114   if (!LastEvaluationResult.has_value()) {
115     Ctx.emitError("Error evaluating model.");
116     return nullptr;
117   }
118   return LastEvaluationResult->getUntypedTensorValue(0);
119 }
120 
121 std::unique_ptr<ModelUnderTrainingRunner>
122 ModelUnderTrainingRunner::createAndEnsureValid(
123     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
124     const std::vector<TensorSpec> &InputSpecs,
125     StringRef OutputSpecsPathOverride) {
126   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
127                                               OutputSpecsPathOverride)) {
128     std::unique_ptr<ModelUnderTrainingRunner> MUTR;
129     std::vector<TensorSpec> OutputSpecs;
130     std::vector<TensorSpec> ExtraOutputsForLogging;
131     append_range(OutputSpecs,
132                  map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
133                    return LFS.Spec;
134                  }));
135     append_range(ExtraOutputsForLogging,
136                  map_range(drop_begin(*MaybeOutputSpecs),
137                            [](const LoggedFeatureSpec &LFS) {
138                              return TensorSpec(LFS.LoggingName
139                                                    ? *LFS.LoggingName
140                                                    : LFS.Spec.name(),
141                                                LFS.Spec);
142                            }));
143 
144     MUTR.reset(new ModelUnderTrainingRunner(
145         Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
146     if (MUTR && MUTR->isValid())
147       return MUTR;
148 
149     Ctx.emitError("Could not load or create model evaluator.");
150     return nullptr;
151   }
152   Ctx.emitError("Could not load the policy model from the provided path");
153   return nullptr;
154 }
155 
156 #endif // defined(LLVM_HAVE_TF_API)
157