1 //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation 10 // happens off a model that's provided from the command line and is interpreted. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/TensorSpec.h" 15 #include "llvm/Config/config.h" 16 #if defined(LLVM_HAVE_TF_API) 17 18 #include "llvm/Analysis/ModelUnderTrainingRunner.h" 19 20 using namespace llvm; 21 22 ModelUnderTrainingRunner::ModelUnderTrainingRunner( 23 LLVMContext &Ctx, const std::string &ModelPath, 24 const std::vector<TensorSpec> &InputSpecs, 25 const std::vector<LoggedFeatureSpec> &OutputSpecs) 26 : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()), 27 OutputSpecs(OutputSpecs) { 28 Evaluator = std::make_unique<TFModelEvaluator>( 29 ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, 30 OutputSpecs.size()); 31 if (!Evaluator || !Evaluator->isValid()) { 32 Ctx.emitError("Failed to create saved model evaluator"); 33 Evaluator.reset(); 34 return; 35 } 36 37 for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) { 38 setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I)); 39 } 40 } 41 42 void *ModelUnderTrainingRunner::evaluateUntyped() { 43 LastEvaluationResult = Evaluator->evaluate(); 44 if (!LastEvaluationResult.hasValue()) { 45 Ctx.emitError("Error evaluating model."); 46 return nullptr; 47 } 48 return LastEvaluationResult->getUntypedTensorValue(0); 49 } 50 51 std::unique_ptr<ModelUnderTrainingRunner> 52 ModelUnderTrainingRunner::createAndEnsureValid( 53 LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, 54 const std::vector<TensorSpec> &InputSpecs, 55 StringRef OutputSpecsPathOverride) { 56 if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath, 57 OutputSpecsPathOverride)) 58 return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs, 59 *MaybeOutputSpecs); 60 Ctx.emitError("Could not load the policy model from the provided path"); 61 return nullptr; 62 } 63 64 std::unique_ptr<ModelUnderTrainingRunner> 65 ModelUnderTrainingRunner::createAndEnsureValid( 66 LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, 67 const std::vector<TensorSpec> &InputSpecs, 68 const std::vector<LoggedFeatureSpec> &OutputSpecs) { 69 std::unique_ptr<ModelUnderTrainingRunner> MUTR; 70 MUTR.reset( 71 new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs)); 72 if (MUTR && MUTR->isValid()) 73 return MUTR; 74 75 Ctx.emitError("Could not load or create model evaluator."); 76 return nullptr; 77 } 78 79 #endif // defined(LLVM_HAVE_TF_API) 80