xref: /llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision d4b6fcb32e29d0cd834a3c89205fef48fbfc1d2d)
//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided from the command line and is interpreted.
//
//===----------------------------------------------------------------------===//
1304f2712eSMircea Trofin 
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <optional>
2104f2712eSMircea Trofin 
2204f2712eSMircea Trofin using namespace llvm;
231ee3bb17SMircea Trofin namespace {
241ee3bb17SMircea Trofin struct LoggedFeatureSpec {
251ee3bb17SMircea Trofin   TensorSpec Spec;
261ee3bb17SMircea Trofin   std::optional<std::string> LoggingName;
271ee3bb17SMircea Trofin };
281ee3bb17SMircea Trofin 
29*d4b6fcb3SFangrui Song std::optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext & Ctx,StringRef ExpectedDecisionName,StringRef ModelPath,StringRef SpecFileOverride)301ee3bb17SMircea Trofin loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
311ee3bb17SMircea Trofin                 StringRef ModelPath, StringRef SpecFileOverride) {
321ee3bb17SMircea Trofin   SmallVector<char, 128> OutputSpecsPath;
331ee3bb17SMircea Trofin   StringRef FileName = SpecFileOverride;
341ee3bb17SMircea Trofin   if (FileName.empty()) {
351ee3bb17SMircea Trofin     llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
361ee3bb17SMircea Trofin     FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
371ee3bb17SMircea Trofin   }
381ee3bb17SMircea Trofin 
391ee3bb17SMircea Trofin   auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
401ee3bb17SMircea Trofin   if (!BufferOrError) {
411ee3bb17SMircea Trofin     Ctx.emitError("Error opening output specs file: " + FileName + " : " +
421ee3bb17SMircea Trofin                   BufferOrError.getError().message());
439c444f70SKazu Hirata     return std::nullopt;
441ee3bb17SMircea Trofin   }
451ee3bb17SMircea Trofin   auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
461ee3bb17SMircea Trofin   if (!ParsedJSONValues) {
471ee3bb17SMircea Trofin     Ctx.emitError("Could not parse specs file: " + FileName);
489c444f70SKazu Hirata     return std::nullopt;
491ee3bb17SMircea Trofin   }
501ee3bb17SMircea Trofin   auto ValuesArray = ParsedJSONValues->getAsArray();
511ee3bb17SMircea Trofin   if (!ValuesArray) {
521ee3bb17SMircea Trofin     Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
531ee3bb17SMircea Trofin                   "logging_name:<name>} dictionaries");
549c444f70SKazu Hirata     return std::nullopt;
551ee3bb17SMircea Trofin   }
561ee3bb17SMircea Trofin   std::vector<LoggedFeatureSpec> Ret;
571ee3bb17SMircea Trofin   for (const auto &Value : *ValuesArray)
581ee3bb17SMircea Trofin     if (const auto *Obj = Value.getAsObject())
591ee3bb17SMircea Trofin       if (const auto *SpecPart = Obj->get("tensor_spec"))
601ee3bb17SMircea Trofin         if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
611ee3bb17SMircea Trofin           if (auto LoggingName = Obj->getString("logging_name")) {
621ee3bb17SMircea Trofin             if (!TensorSpec->isElementType<int64_t>() &&
631ee3bb17SMircea Trofin                 !TensorSpec->isElementType<int32_t>() &&
641ee3bb17SMircea Trofin                 !TensorSpec->isElementType<float>()) {
651ee3bb17SMircea Trofin               Ctx.emitError(
661ee3bb17SMircea Trofin                   "Only int64, int32, and float tensors are supported. "
671ee3bb17SMircea Trofin                   "Found unsupported type for tensor named " +
681ee3bb17SMircea Trofin                   TensorSpec->name());
699c444f70SKazu Hirata               return std::nullopt;
701ee3bb17SMircea Trofin             }
711ee3bb17SMircea Trofin             Ret.push_back({*TensorSpec, LoggingName->str()});
721ee3bb17SMircea Trofin           }
731ee3bb17SMircea Trofin 
741ee3bb17SMircea Trofin   if (ValuesArray->size() != Ret.size()) {
751ee3bb17SMircea Trofin     Ctx.emitError(
761ee3bb17SMircea Trofin         "Unable to parse output spec. It should be a json file containing an "
771ee3bb17SMircea Trofin         "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
781ee3bb17SMircea Trofin         "with a json object describing a TensorSpec; and a 'logging_name' key, "
791ee3bb17SMircea Trofin         "which is a string to use as name when logging this tensor in the "
801ee3bb17SMircea Trofin         "training log.");
819c444f70SKazu Hirata     return std::nullopt;
821ee3bb17SMircea Trofin   }
831ee3bb17SMircea Trofin   if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
841ee3bb17SMircea Trofin     Ctx.emitError("The first output spec must describe the decision tensor, "
851ee3bb17SMircea Trofin                   "and must have the logging_name " +
861ee3bb17SMircea Trofin                   StringRef(ExpectedDecisionName));
879c444f70SKazu Hirata     return std::nullopt;
881ee3bb17SMircea Trofin   }
891ee3bb17SMircea Trofin   return Ret;
901ee3bb17SMircea Trofin }
911ee3bb17SMircea Trofin } // namespace
9204f2712eSMircea Trofin 
ModelUnderTrainingRunner(LLVMContext & Ctx,const std::string & ModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const std::vector<TensorSpec> & ExtraOutputsForLogging)9304f2712eSMircea Trofin ModelUnderTrainingRunner::ModelUnderTrainingRunner(
9404f2712eSMircea Trofin     LLVMContext &Ctx, const std::string &ModelPath,
9504f2712eSMircea Trofin     const std::vector<TensorSpec> &InputSpecs,
961ee3bb17SMircea Trofin     const std::vector<TensorSpec> &OutputSpecs,
971ee3bb17SMircea Trofin     const std::vector<TensorSpec> &ExtraOutputsForLogging)
98c35ad9eeSMircea Trofin     : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
991ee3bb17SMircea Trofin       OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
1001ee3bb17SMircea Trofin   Evaluator =
1011ee3bb17SMircea Trofin       std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
10204f2712eSMircea Trofin   if (!Evaluator || !Evaluator->isValid()) {
103a81b0c97SMircea Trofin     Ctx.emitError("Failed to create saved model evaluator");
10404f2712eSMircea Trofin     Evaluator.reset();
10504f2712eSMircea Trofin     return;
10604f2712eSMircea Trofin   }
107c35ad9eeSMircea Trofin 
108c35ad9eeSMircea Trofin   for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
109c35ad9eeSMircea Trofin     setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
110c35ad9eeSMircea Trofin   }
11104f2712eSMircea Trofin }
11204f2712eSMircea Trofin 
evaluateUntyped()11304f2712eSMircea Trofin void *ModelUnderTrainingRunner::evaluateUntyped() {
11404f2712eSMircea Trofin   LastEvaluationResult = Evaluator->evaluate();
11565abca46SAiden Grossman   if (!LastEvaluationResult.has_value()) {
11604f2712eSMircea Trofin     Ctx.emitError("Error evaluating model.");
11704f2712eSMircea Trofin     return nullptr;
11804f2712eSMircea Trofin   }
11904f2712eSMircea Trofin   return LastEvaluationResult->getUntypedTensorValue(0);
12004f2712eSMircea Trofin }
12104f2712eSMircea Trofin 
122c35ad9eeSMircea Trofin std::unique_ptr<ModelUnderTrainingRunner>
createAndEnsureValid(LLVMContext & Ctx,const std::string & ModelPath,StringRef DecisionName,const std::vector<TensorSpec> & InputSpecs,StringRef OutputSpecsPathOverride)123c35ad9eeSMircea Trofin ModelUnderTrainingRunner::createAndEnsureValid(
124c35ad9eeSMircea Trofin     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
125c35ad9eeSMircea Trofin     const std::vector<TensorSpec> &InputSpecs,
126c35ad9eeSMircea Trofin     StringRef OutputSpecsPathOverride) {
127c35ad9eeSMircea Trofin   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
1281ee3bb17SMircea Trofin                                               OutputSpecsPathOverride)) {
129a120fdd3SMircea Trofin     std::unique_ptr<ModelUnderTrainingRunner> MUTR;
1301ee3bb17SMircea Trofin     std::vector<TensorSpec> OutputSpecs;
1311ee3bb17SMircea Trofin     std::vector<TensorSpec> ExtraOutputsForLogging;
1321ee3bb17SMircea Trofin     append_range(OutputSpecs,
1331ee3bb17SMircea Trofin                  map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
1341ee3bb17SMircea Trofin                    return LFS.Spec;
1351ee3bb17SMircea Trofin                  }));
1361ee3bb17SMircea Trofin     append_range(ExtraOutputsForLogging,
1371ee3bb17SMircea Trofin                  map_range(drop_begin(*MaybeOutputSpecs),
1381ee3bb17SMircea Trofin                            [](const LoggedFeatureSpec &LFS) {
1391ee3bb17SMircea Trofin                              return TensorSpec(LFS.LoggingName
1401ee3bb17SMircea Trofin                                                    ? *LFS.LoggingName
1411ee3bb17SMircea Trofin                                                    : LFS.Spec.name(),
1421ee3bb17SMircea Trofin                                                LFS.Spec);
1431ee3bb17SMircea Trofin                            }));
1441ee3bb17SMircea Trofin 
1451ee3bb17SMircea Trofin     MUTR.reset(new ModelUnderTrainingRunner(
1461ee3bb17SMircea Trofin         Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
147a120fdd3SMircea Trofin     if (MUTR && MUTR->isValid())
148a120fdd3SMircea Trofin       return MUTR;
149a120fdd3SMircea Trofin 
150c35ad9eeSMircea Trofin     Ctx.emitError("Could not load or create model evaluator.");
151a120fdd3SMircea Trofin     return nullptr;
152a120fdd3SMircea Trofin   }
1531ee3bb17SMircea Trofin   Ctx.emitError("Could not load the policy model from the provided path");
1541ee3bb17SMircea Trofin   return nullptr;
1551ee3bb17SMircea Trofin }
156a120fdd3SMircea Trofin 
157edc83a15SKazu Hirata #endif // defined(LLVM_HAVE_TFLITE)
158