xref: /freebsd-src/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision 04eeddc0aa8e0a417a16eaf9d7d095207f4a8623)
10eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
20eae32dcSDimitry Andric //
30eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60eae32dcSDimitry Andric //
70eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
80eae32dcSDimitry Andric //
90eae32dcSDimitry Andric // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
100eae32dcSDimitry Andric // happens off a model that's provided from the command line and is interpreted.
110eae32dcSDimitry Andric //
120eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
130eae32dcSDimitry Andric 
140eae32dcSDimitry Andric #include "llvm/Config/config.h"
150eae32dcSDimitry Andric #if defined(LLVM_HAVE_TF_API)
160eae32dcSDimitry Andric 
170eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h"
180eae32dcSDimitry Andric 
190eae32dcSDimitry Andric using namespace llvm;
200eae32dcSDimitry Andric 
210eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner(
220eae32dcSDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath,
230eae32dcSDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
240eae32dcSDimitry Andric     const std::vector<LoggedFeatureSpec> &OutputSpecs)
25*04eeddc0SDimitry Andric     : MLModelRunner(Ctx, MLModelRunner::Kind::Development),
26*04eeddc0SDimitry Andric       OutputSpecs(OutputSpecs) {
270eae32dcSDimitry Andric   Evaluator = std::make_unique<TFModelEvaluator>(
280eae32dcSDimitry Andric       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
290eae32dcSDimitry Andric       OutputSpecs.size());
300eae32dcSDimitry Andric   if (!Evaluator || !Evaluator->isValid()) {
31*04eeddc0SDimitry Andric     Ctx.emitError("Failed to create saved model evaluator");
320eae32dcSDimitry Andric     Evaluator.reset();
330eae32dcSDimitry Andric     return;
340eae32dcSDimitry Andric   }
350eae32dcSDimitry Andric }
360eae32dcSDimitry Andric 
370eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() {
380eae32dcSDimitry Andric   LastEvaluationResult = Evaluator->evaluate();
390eae32dcSDimitry Andric   if (!LastEvaluationResult.hasValue()) {
400eae32dcSDimitry Andric     Ctx.emitError("Error evaluating model.");
410eae32dcSDimitry Andric     return nullptr;
420eae32dcSDimitry Andric   }
430eae32dcSDimitry Andric   return LastEvaluationResult->getUntypedTensorValue(0);
440eae32dcSDimitry Andric }
450eae32dcSDimitry Andric 
460eae32dcSDimitry Andric void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
470eae32dcSDimitry Andric   return Evaluator->getUntypedInput(Index);
480eae32dcSDimitry Andric }
490eae32dcSDimitry Andric 
50*04eeddc0SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner>
51*04eeddc0SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid(
52*04eeddc0SDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
53*04eeddc0SDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
54*04eeddc0SDimitry Andric     StringRef OutputSpecsPathOverride) {
55*04eeddc0SDimitry Andric   std::unique_ptr<ModelUnderTrainingRunner> MUTR;
56*04eeddc0SDimitry Andric   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
57*04eeddc0SDimitry Andric                                               OutputSpecsPathOverride))
58*04eeddc0SDimitry Andric     MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs,
59*04eeddc0SDimitry Andric                                             *MaybeOutputSpecs));
60*04eeddc0SDimitry Andric   if (MUTR && MUTR->isValid())
61*04eeddc0SDimitry Andric     return MUTR;
62*04eeddc0SDimitry Andric 
63*04eeddc0SDimitry Andric   Ctx.emitError("Could not load the policy model from the provided path");
64*04eeddc0SDimitry Andric   return nullptr;
65*04eeddc0SDimitry Andric }
66*04eeddc0SDimitry Andric 
670eae32dcSDimitry Andric #endif // defined(LLVM_HAVE_TF_API)
68