xref: /freebsd-src/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision 0eae32dcef82f6f06de6419a0d623d7def0cc8f6)
1 //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
10 // happens off a model that's provided from the command line and is interpreted.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Config/config.h"
15 #if defined(LLVM_HAVE_TF_API)
16 
17 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
18 
19 using namespace llvm;
20 
21 ModelUnderTrainingRunner::ModelUnderTrainingRunner(
22     LLVMContext &Ctx, const std::string &ModelPath,
23     const std::vector<TensorSpec> &InputSpecs,
24     const std::vector<LoggedFeatureSpec> &OutputSpecs)
25     : MLModelRunner(Ctx), OutputSpecs(OutputSpecs) {
26   Evaluator = std::make_unique<TFModelEvaluator>(
27       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
28       OutputSpecs.size());
29   if (!Evaluator || !Evaluator->isValid()) {
30     Ctx.emitError("Failed to create inliner saved model evaluator");
31     Evaluator.reset();
32     return;
33   }
34 }
35 
36 void *ModelUnderTrainingRunner::evaluateUntyped() {
37   LastEvaluationResult = Evaluator->evaluate();
38   if (!LastEvaluationResult.hasValue()) {
39     Ctx.emitError("Error evaluating model.");
40     return nullptr;
41   }
42   return LastEvaluationResult->getUntypedTensorValue(0);
43 }
44 
45 void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
46   return Evaluator->getUntypedInput(Index);
47 }
48 
49 #endif // defined(LLVM_HAVE_TF_API)
50