10eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===// 20eae32dcSDimitry Andric // 30eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 50eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60eae32dcSDimitry Andric // 70eae32dcSDimitry Andric //===----------------------------------------------------------------------===// 80eae32dcSDimitry Andric // 90eae32dcSDimitry Andric // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation 100eae32dcSDimitry Andric // happens off a model that's provided from the command line and is interpreted. 110eae32dcSDimitry Andric // 120eae32dcSDimitry Andric //===----------------------------------------------------------------------===// 130eae32dcSDimitry Andric 140eae32dcSDimitry Andric #include "llvm/Config/config.h" 150eae32dcSDimitry Andric #if defined(LLVM_HAVE_TF_API) 160eae32dcSDimitry Andric 170eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h" 180eae32dcSDimitry Andric 190eae32dcSDimitry Andric using namespace llvm; 200eae32dcSDimitry Andric 210eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner( 220eae32dcSDimitry Andric LLVMContext &Ctx, const std::string &ModelPath, 230eae32dcSDimitry Andric const std::vector<TensorSpec> &InputSpecs, 240eae32dcSDimitry Andric const std::vector<LoggedFeatureSpec> &OutputSpecs) 25*04eeddc0SDimitry Andric : MLModelRunner(Ctx, MLModelRunner::Kind::Development), 26*04eeddc0SDimitry Andric OutputSpecs(OutputSpecs) { 270eae32dcSDimitry Andric Evaluator = std::make_unique<TFModelEvaluator>( 280eae32dcSDimitry Andric ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, 290eae32dcSDimitry Andric OutputSpecs.size()); 300eae32dcSDimitry Andric if (!Evaluator || !Evaluator->isValid()) { 31*04eeddc0SDimitry Andric Ctx.emitError("Failed to create saved model evaluator"); 320eae32dcSDimitry Andric Evaluator.reset(); 330eae32dcSDimitry Andric return; 340eae32dcSDimitry Andric } 350eae32dcSDimitry Andric } 360eae32dcSDimitry Andric 370eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() { 380eae32dcSDimitry Andric LastEvaluationResult = Evaluator->evaluate(); 390eae32dcSDimitry Andric if (!LastEvaluationResult.hasValue()) { 400eae32dcSDimitry Andric Ctx.emitError("Error evaluating model."); 410eae32dcSDimitry Andric return nullptr; 420eae32dcSDimitry Andric } 430eae32dcSDimitry Andric return LastEvaluationResult->getUntypedTensorValue(0); 440eae32dcSDimitry Andric } 450eae32dcSDimitry Andric 460eae32dcSDimitry Andric void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) { 470eae32dcSDimitry Andric return Evaluator->getUntypedInput(Index); 480eae32dcSDimitry Andric } 490eae32dcSDimitry Andric 50*04eeddc0SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner> 51*04eeddc0SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid( 52*04eeddc0SDimitry Andric LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, 53*04eeddc0SDimitry Andric const std::vector<TensorSpec> &InputSpecs, 54*04eeddc0SDimitry Andric StringRef OutputSpecsPathOverride) { 55*04eeddc0SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner> MUTR; 56*04eeddc0SDimitry Andric if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath, 57*04eeddc0SDimitry Andric OutputSpecsPathOverride)) 58*04eeddc0SDimitry Andric MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, 59*04eeddc0SDimitry Andric *MaybeOutputSpecs)); 60*04eeddc0SDimitry Andric if (MUTR && MUTR->isValid()) 61*04eeddc0SDimitry Andric return MUTR; 62*04eeddc0SDimitry Andric 63*04eeddc0SDimitry Andric Ctx.emitError("Could not load the policy model from the provided path"); 64*04eeddc0SDimitry Andric return nullptr; 65*04eeddc0SDimitry Andric } 66*04eeddc0SDimitry Andric 670eae32dcSDimitry Andric #endif // defined(LLVM_HAVE_TF_API) 68