1 //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation 10 // happens off a model that's provided from the command line and is interpreted. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Config/config.h" 15 #if defined(LLVM_HAVE_TF_API) 16 17 #include "llvm/Analysis/ModelUnderTrainingRunner.h" 18 19 using namespace llvm; 20 21 ModelUnderTrainingRunner::ModelUnderTrainingRunner( 22 LLVMContext &Ctx, const std::string &ModelPath, 23 const std::vector<TensorSpec> &InputSpecs, 24 const std::vector<LoggedFeatureSpec> &OutputSpecs) 25 : MLModelRunner(Ctx), OutputSpecs(OutputSpecs) { 26 Evaluator = std::make_unique<TFModelEvaluator>( 27 ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, 28 OutputSpecs.size()); 29 if (!Evaluator || !Evaluator->isValid()) { 30 Ctx.emitError("Failed to create inliner saved model evaluator"); 31 Evaluator.reset(); 32 return; 33 } 34 } 35 36 void *ModelUnderTrainingRunner::evaluateUntyped() { 37 LastEvaluationResult = Evaluator->evaluate(); 38 if (!LastEvaluationResult.hasValue()) { 39 Ctx.emitError("Error evaluating model."); 40 return nullptr; 41 } 42 return LastEvaluationResult->getUntypedTensorValue(0); 43 } 44 45 void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) { 46 return Evaluator->getUntypedInput(Index); 47 } 48 49 #endif // defined(LLVM_HAVE_TF_API) 50