//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided from the command line and is interpreted.
//
//===----------------------------------------------------------------------===//

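// A minimal usage sketch (illustrative only, not part of this file). A client
// holding an LLVMContext `Ctx`, a saved-model path, the decision tensor's
// name, and the input TensorSpecs would do roughly:
//
//   auto Runner = ModelUnderTrainingRunner::createAndEnsureValid(
//       Ctx, ModelPath, DecisionName, InputSpecs,
//       /*OutputSpecsPathOverride=*/"");
//   if (Runner) {
//     // populate the inputs via the MLModelRunner accessors, then:
//     int64_t Advice = Runner->evaluate<int64_t>();
//   }
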
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)

#include "llvm/Analysis/ModelUnderTrainingRunner.h"

using namespace llvm;

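// Set up a TFModelEvaluator over the saved model at ModelPath, using the given
// input specs and the tensor specs of the logged output features. If the
// evaluator cannot be created, emit an error on Ctx and leave the runner
// invalid; otherwise, bind each input spec to the evaluator's corresponding
// input buffer.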
ModelUnderTrainingRunner::ModelUnderTrainingRunner(
    LLVMContext &Ctx, const std::string &ModelPath,
    const std::vector<TensorSpec> &InputSpecs,
    const std::vector<LoggedFeatureSpec> &OutputSpecs)
    : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
      OutputSpecs(OutputSpecs) {
  Evaluator = std::make_unique<TFModelEvaluator>(
      ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
      OutputSpecs.size());
  if (!Evaluator || !Evaluator->isValid()) {
    Ctx.emitError("Failed to create saved model evaluator");
    Evaluator.reset();
    return;
  }

  for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
    setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
  }
}

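// Run one evaluation over the currently populated input buffers. The result is
// cached so the returned tensor storage stays alive until the next evaluation;
// on failure, emit an error on the associated context and return nullptr,
// otherwise return the raw buffer of the first output tensor.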
void *ModelUnderTrainingRunner::evaluateUntyped() {
  LastEvaluationResult = Evaluator->evaluate();
  if (!LastEvaluationResult.hasValue()) {
    Ctx.emitError("Error evaluating model.");
    return nullptr;
  }
  return LastEvaluationResult->getUntypedTensorValue(0);
}

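// Convenience factory: load the output specs associated with the model (or
// from OutputSpecsPathOverride, if provided), then defer to the overload
// below. Emits an error and returns nullptr if the specs cannot be loaded.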
std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
    const std::vector<TensorSpec> &InputSpecs,
    StringRef OutputSpecsPathOverride) {
  if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
                                              OutputSpecsPathOverride))
    return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
                                *MaybeOutputSpecs);
  Ctx.emitError("Could not load the policy model from the provided path");
  return nullptr;
}

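// Construct the runner from already-loaded output specs and return it only if
// the underlying evaluator was set up successfully; otherwise emit an error
// and return nullptr.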
std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
    const std::vector<TensorSpec> &InputSpecs,
    const std::vector<LoggedFeatureSpec> &OutputSpecs) {
  std::unique_ptr<ModelUnderTrainingRunner> MUTR;
  MUTR.reset(
      new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
  if (MUTR && MUTR->isValid())
    return MUTR;

  Ctx.emitError("Could not load or create model evaluator.");
  return nullptr;
}

#endif // defined(LLVM_HAVE_TF_API)