xref: /freebsd-src/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
10eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
20eae32dcSDimitry Andric //
30eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60eae32dcSDimitry Andric //
70eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
80eae32dcSDimitry Andric //
90eae32dcSDimitry Andric // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
100eae32dcSDimitry Andric // happens off a model that's provided from the command line and is interpreted.
110eae32dcSDimitry Andric //
120eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
130eae32dcSDimitry Andric 
14*bdd1243dSDimitry Andric #include "llvm/ADT/STLExtras.h"
150eae32dcSDimitry Andric #include "llvm/Config/config.h"
16*bdd1243dSDimitry Andric #if defined(LLVM_HAVE_TFLITE)
170eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h"
18*bdd1243dSDimitry Andric #include "llvm/Support/MemoryBuffer.h"
19*bdd1243dSDimitry Andric #include "llvm/Support/Path.h"
20*bdd1243dSDimitry Andric #include <optional>
210eae32dcSDimitry Andric 
220eae32dcSDimitry Andric using namespace llvm;
23*bdd1243dSDimitry Andric namespace {
24*bdd1243dSDimitry Andric struct LoggedFeatureSpec {
25*bdd1243dSDimitry Andric   TensorSpec Spec;
26*bdd1243dSDimitry Andric   std::optional<std::string> LoggingName;
27*bdd1243dSDimitry Andric };
28*bdd1243dSDimitry Andric 
29*bdd1243dSDimitry Andric std::optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext & Ctx,StringRef ExpectedDecisionName,StringRef ModelPath,StringRef SpecFileOverride)30*bdd1243dSDimitry Andric loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
31*bdd1243dSDimitry Andric                 StringRef ModelPath, StringRef SpecFileOverride) {
32*bdd1243dSDimitry Andric   SmallVector<char, 128> OutputSpecsPath;
33*bdd1243dSDimitry Andric   StringRef FileName = SpecFileOverride;
34*bdd1243dSDimitry Andric   if (FileName.empty()) {
35*bdd1243dSDimitry Andric     llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
36*bdd1243dSDimitry Andric     FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
37*bdd1243dSDimitry Andric   }
38*bdd1243dSDimitry Andric 
39*bdd1243dSDimitry Andric   auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
40*bdd1243dSDimitry Andric   if (!BufferOrError) {
41*bdd1243dSDimitry Andric     Ctx.emitError("Error opening output specs file: " + FileName + " : " +
42*bdd1243dSDimitry Andric                   BufferOrError.getError().message());
43*bdd1243dSDimitry Andric     return std::nullopt;
44*bdd1243dSDimitry Andric   }
45*bdd1243dSDimitry Andric   auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
46*bdd1243dSDimitry Andric   if (!ParsedJSONValues) {
47*bdd1243dSDimitry Andric     Ctx.emitError("Could not parse specs file: " + FileName);
48*bdd1243dSDimitry Andric     return std::nullopt;
49*bdd1243dSDimitry Andric   }
50*bdd1243dSDimitry Andric   auto ValuesArray = ParsedJSONValues->getAsArray();
51*bdd1243dSDimitry Andric   if (!ValuesArray) {
52*bdd1243dSDimitry Andric     Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
53*bdd1243dSDimitry Andric                   "logging_name:<name>} dictionaries");
54*bdd1243dSDimitry Andric     return std::nullopt;
55*bdd1243dSDimitry Andric   }
56*bdd1243dSDimitry Andric   std::vector<LoggedFeatureSpec> Ret;
57*bdd1243dSDimitry Andric   for (const auto &Value : *ValuesArray)
58*bdd1243dSDimitry Andric     if (const auto *Obj = Value.getAsObject())
59*bdd1243dSDimitry Andric       if (const auto *SpecPart = Obj->get("tensor_spec"))
60*bdd1243dSDimitry Andric         if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
61*bdd1243dSDimitry Andric           if (auto LoggingName = Obj->getString("logging_name")) {
62*bdd1243dSDimitry Andric             if (!TensorSpec->isElementType<int64_t>() &&
63*bdd1243dSDimitry Andric                 !TensorSpec->isElementType<int32_t>() &&
64*bdd1243dSDimitry Andric                 !TensorSpec->isElementType<float>()) {
65*bdd1243dSDimitry Andric               Ctx.emitError(
66*bdd1243dSDimitry Andric                   "Only int64, int32, and float tensors are supported. "
67*bdd1243dSDimitry Andric                   "Found unsupported type for tensor named " +
68*bdd1243dSDimitry Andric                   TensorSpec->name());
69*bdd1243dSDimitry Andric               return std::nullopt;
70*bdd1243dSDimitry Andric             }
71*bdd1243dSDimitry Andric             Ret.push_back({*TensorSpec, LoggingName->str()});
72*bdd1243dSDimitry Andric           }
73*bdd1243dSDimitry Andric 
74*bdd1243dSDimitry Andric   if (ValuesArray->size() != Ret.size()) {
75*bdd1243dSDimitry Andric     Ctx.emitError(
76*bdd1243dSDimitry Andric         "Unable to parse output spec. It should be a json file containing an "
77*bdd1243dSDimitry Andric         "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
78*bdd1243dSDimitry Andric         "with a json object describing a TensorSpec; and a 'logging_name' key, "
79*bdd1243dSDimitry Andric         "which is a string to use as name when logging this tensor in the "
80*bdd1243dSDimitry Andric         "training log.");
81*bdd1243dSDimitry Andric     return std::nullopt;
82*bdd1243dSDimitry Andric   }
83*bdd1243dSDimitry Andric   if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
84*bdd1243dSDimitry Andric     Ctx.emitError("The first output spec must describe the decision tensor, "
85*bdd1243dSDimitry Andric                   "and must have the logging_name " +
86*bdd1243dSDimitry Andric                   StringRef(ExpectedDecisionName));
87*bdd1243dSDimitry Andric     return std::nullopt;
88*bdd1243dSDimitry Andric   }
89*bdd1243dSDimitry Andric   return Ret;
90*bdd1243dSDimitry Andric }
91*bdd1243dSDimitry Andric } // namespace
920eae32dcSDimitry Andric 
ModelUnderTrainingRunner(LLVMContext & Ctx,const std::string & ModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const std::vector<TensorSpec> & ExtraOutputsForLogging)930eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner(
940eae32dcSDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath,
950eae32dcSDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
96*bdd1243dSDimitry Andric     const std::vector<TensorSpec> &OutputSpecs,
97*bdd1243dSDimitry Andric     const std::vector<TensorSpec> &ExtraOutputsForLogging)
9881ad6265SDimitry Andric     : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
99*bdd1243dSDimitry Andric       OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
100*bdd1243dSDimitry Andric   Evaluator =
101*bdd1243dSDimitry Andric       std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
1020eae32dcSDimitry Andric   if (!Evaluator || !Evaluator->isValid()) {
10304eeddc0SDimitry Andric     Ctx.emitError("Failed to create saved model evaluator");
1040eae32dcSDimitry Andric     Evaluator.reset();
1050eae32dcSDimitry Andric     return;
1060eae32dcSDimitry Andric   }
10781ad6265SDimitry Andric 
10881ad6265SDimitry Andric   for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
10981ad6265SDimitry Andric     setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
11081ad6265SDimitry Andric   }
1110eae32dcSDimitry Andric }
1120eae32dcSDimitry Andric 
evaluateUntyped()1130eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() {
1140eae32dcSDimitry Andric   LastEvaluationResult = Evaluator->evaluate();
115*bdd1243dSDimitry Andric   if (!LastEvaluationResult.has_value()) {
1160eae32dcSDimitry Andric     Ctx.emitError("Error evaluating model.");
1170eae32dcSDimitry Andric     return nullptr;
1180eae32dcSDimitry Andric   }
1190eae32dcSDimitry Andric   return LastEvaluationResult->getUntypedTensorValue(0);
1200eae32dcSDimitry Andric }
1210eae32dcSDimitry Andric 
12281ad6265SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner>
createAndEnsureValid(LLVMContext & Ctx,const std::string & ModelPath,StringRef DecisionName,const std::vector<TensorSpec> & InputSpecs,StringRef OutputSpecsPathOverride)12381ad6265SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid(
12481ad6265SDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
12581ad6265SDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
12681ad6265SDimitry Andric     StringRef OutputSpecsPathOverride) {
12781ad6265SDimitry Andric   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
128*bdd1243dSDimitry Andric                                               OutputSpecsPathOverride)) {
12904eeddc0SDimitry Andric     std::unique_ptr<ModelUnderTrainingRunner> MUTR;
130*bdd1243dSDimitry Andric     std::vector<TensorSpec> OutputSpecs;
131*bdd1243dSDimitry Andric     std::vector<TensorSpec> ExtraOutputsForLogging;
132*bdd1243dSDimitry Andric     append_range(OutputSpecs,
133*bdd1243dSDimitry Andric                  map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
134*bdd1243dSDimitry Andric                    return LFS.Spec;
135*bdd1243dSDimitry Andric                  }));
136*bdd1243dSDimitry Andric     append_range(ExtraOutputsForLogging,
137*bdd1243dSDimitry Andric                  map_range(drop_begin(*MaybeOutputSpecs),
138*bdd1243dSDimitry Andric                            [](const LoggedFeatureSpec &LFS) {
139*bdd1243dSDimitry Andric                              return TensorSpec(LFS.LoggingName
140*bdd1243dSDimitry Andric                                                    ? *LFS.LoggingName
141*bdd1243dSDimitry Andric                                                    : LFS.Spec.name(),
142*bdd1243dSDimitry Andric                                                LFS.Spec);
143*bdd1243dSDimitry Andric                            }));
144*bdd1243dSDimitry Andric 
145*bdd1243dSDimitry Andric     MUTR.reset(new ModelUnderTrainingRunner(
146*bdd1243dSDimitry Andric         Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
14704eeddc0SDimitry Andric     if (MUTR && MUTR->isValid())
14804eeddc0SDimitry Andric       return MUTR;
14904eeddc0SDimitry Andric 
15081ad6265SDimitry Andric     Ctx.emitError("Could not load or create model evaluator.");
15104eeddc0SDimitry Andric     return nullptr;
15204eeddc0SDimitry Andric   }
153*bdd1243dSDimitry Andric   Ctx.emitError("Could not load the policy model from the provided path");
154*bdd1243dSDimitry Andric   return nullptr;
155*bdd1243dSDimitry Andric }
15604eeddc0SDimitry Andric 
157*bdd1243dSDimitry Andric #endif // defined(LLVM_HAVE_TFLITE)
158