xref: /llvm-project/llvm/lib/Analysis/TFLiteUtils.cpp (revision 28bb2193f6d3bb52f6bba9c64e392fe6c8be0f88)
//===- TFLiteUtils.cpp - TFLite-based evaluation utilities ----------------===//
25ce4c9aaSMircea Trofin //
35ce4c9aaSMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45ce4c9aaSMircea Trofin // See https://llvm.org/LICENSE.txt for license information.
55ce4c9aaSMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65ce4c9aaSMircea Trofin //
75ce4c9aaSMircea Trofin //===----------------------------------------------------------------------===//
85ce4c9aaSMircea Trofin //
9*28bb2193SMircea Trofin // This file implements utilities for interfacing with TFLite.
105ce4c9aaSMircea Trofin //
115ce4c9aaSMircea Trofin //===----------------------------------------------------------------------===//
125ce4c9aaSMircea Trofin #include "llvm/Config/config.h"
135ce4c9aaSMircea Trofin #if defined(LLVM_HAVE_TFLITE)
145ce4c9aaSMircea Trofin 
155ce4c9aaSMircea Trofin #include "llvm/ADT/Twine.h"
165ce4c9aaSMircea Trofin #include "llvm/Analysis/Utils/TFUtils.h"
175ce4c9aaSMircea Trofin #include "llvm/Support/Base64.h"
185ce4c9aaSMircea Trofin #include "llvm/Support/CommandLine.h"
195ce4c9aaSMircea Trofin #include "llvm/Support/Debug.h"
205ce4c9aaSMircea Trofin #include "llvm/Support/JSON.h"
215ce4c9aaSMircea Trofin #include "llvm/Support/MemoryBuffer.h"
225ce4c9aaSMircea Trofin #include "llvm/Support/Path.h"
235ce4c9aaSMircea Trofin #include "llvm/Support/raw_ostream.h"
245ce4c9aaSMircea Trofin 
255ce4c9aaSMircea Trofin #include "tensorflow/lite/interpreter.h"
265ce4c9aaSMircea Trofin #include "tensorflow/lite/kernels/register.h"
275ce4c9aaSMircea Trofin #include "tensorflow/lite/model.h"
285ce4c9aaSMircea Trofin #include "tensorflow/lite/model_builder.h"
295ce4c9aaSMircea Trofin #include "tensorflow/lite/op_resolver.h"
30a219a8a8SMircea Trofin #include "tensorflow/lite/logger.h"
315ce4c9aaSMircea Trofin 
325ce4c9aaSMircea Trofin #include <cassert>
335ce4c9aaSMircea Trofin #include <numeric>
349c444f70SKazu Hirata #include <optional>
355ce4c9aaSMircea Trofin 
365ce4c9aaSMircea Trofin using namespace llvm;
375ce4c9aaSMircea Trofin 
385ce4c9aaSMircea Trofin namespace llvm {
395ce4c9aaSMircea Trofin class EvaluationResultImpl {
405ce4c9aaSMircea Trofin public:
EvaluationResultImpl(const std::vector<const TfLiteTensor * > & Outputs)415ce4c9aaSMircea Trofin   EvaluationResultImpl(const std::vector<const TfLiteTensor *> &Outputs)
425ce4c9aaSMircea Trofin       : Outputs(Outputs){};
435ce4c9aaSMircea Trofin 
getOutput(size_t I)445ce4c9aaSMircea Trofin   const TfLiteTensor *getOutput(size_t I) { return Outputs[I]; }
455ce4c9aaSMircea Trofin 
465ce4c9aaSMircea Trofin   EvaluationResultImpl(const EvaluationResultImpl &) = delete;
475ce4c9aaSMircea Trofin   EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
485ce4c9aaSMircea Trofin 
495ce4c9aaSMircea Trofin private:
505ce4c9aaSMircea Trofin   const std::vector<const TfLiteTensor *> Outputs;
515ce4c9aaSMircea Trofin };
525ce4c9aaSMircea Trofin 
535ce4c9aaSMircea Trofin class TFModelEvaluatorImpl {
545ce4c9aaSMircea Trofin public:
555ce4c9aaSMircea Trofin   TFModelEvaluatorImpl(StringRef SavedModelPath,
565ce4c9aaSMircea Trofin                        const std::vector<TensorSpec> &InputSpecs,
571ee3bb17SMircea Trofin                        const std::vector<TensorSpec> &OutputSpecs,
581ee3bb17SMircea Trofin                        const char *Tags);
595ce4c9aaSMircea Trofin 
isValid() const605ce4c9aaSMircea Trofin   bool isValid() const { return IsValid; }
outputSize() const615ce4c9aaSMircea Trofin   size_t outputSize() const { return Output.size(); }
625ce4c9aaSMircea Trofin 
evaluate()635ce4c9aaSMircea Trofin   std::unique_ptr<EvaluationResultImpl> evaluate() {
645ce4c9aaSMircea Trofin     Interpreter->Invoke();
655ce4c9aaSMircea Trofin     return std::make_unique<EvaluationResultImpl>(Output);
665ce4c9aaSMircea Trofin   }
675ce4c9aaSMircea Trofin 
getInput() const685ce4c9aaSMircea Trofin   const std::vector<TfLiteTensor *> &getInput() const { return Input; }
695ce4c9aaSMircea Trofin 
705ce4c9aaSMircea Trofin   ~TFModelEvaluatorImpl();
715ce4c9aaSMircea Trofin 
725ce4c9aaSMircea Trofin private:
735ce4c9aaSMircea Trofin   std::unique_ptr<tflite::FlatBufferModel> Model;
745ce4c9aaSMircea Trofin 
755ce4c9aaSMircea Trofin   /// The objects necessary for carrying out an evaluation of the SavedModel.
765ce4c9aaSMircea Trofin   /// They are expensive to set up, and we maintain them accross all the
775ce4c9aaSMircea Trofin   /// evaluations of the model.
785ce4c9aaSMircea Trofin   std::unique_ptr<tflite::Interpreter> Interpreter;
795ce4c9aaSMircea Trofin 
805ce4c9aaSMircea Trofin   /// The input tensors. We set up the tensors once and just mutate theirs
815ce4c9aaSMircea Trofin   /// scalars before each evaluation. The input tensors keep their value after
825ce4c9aaSMircea Trofin   /// an evaluation.
835ce4c9aaSMircea Trofin   std::vector<TfLiteTensor *> Input;
845ce4c9aaSMircea Trofin 
855ce4c9aaSMircea Trofin   /// The output nodes.
865ce4c9aaSMircea Trofin   std::vector<const TfLiteTensor *> Output;
875ce4c9aaSMircea Trofin 
invalidate()885ce4c9aaSMircea Trofin   void invalidate() { IsValid = false; }
895ce4c9aaSMircea Trofin 
905ce4c9aaSMircea Trofin   bool IsValid = true;
915ce4c9aaSMircea Trofin 
925ce4c9aaSMircea Trofin   /// Reusable utility for ensuring we can bind the requested Name to a node in
935ce4c9aaSMircea Trofin   /// the SavedModel Graph.
945ce4c9aaSMircea Trofin   bool checkReportAndInvalidate(const TfLiteTensor *Tensor,
955ce4c9aaSMircea Trofin                                 const TensorSpec &Spec);
965ce4c9aaSMircea Trofin };
975ce4c9aaSMircea Trofin 
985ce4c9aaSMircea Trofin } // namespace llvm
995ce4c9aaSMircea Trofin 
TFModelEvaluatorImpl(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const char * Tags="serve")1005ce4c9aaSMircea Trofin TFModelEvaluatorImpl::TFModelEvaluatorImpl(
1015ce4c9aaSMircea Trofin     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
1021ee3bb17SMircea Trofin     const std::vector<TensorSpec> &OutputSpecs, const char *Tags = "serve")
1031ee3bb17SMircea Trofin     : Input(InputSpecs.size()), Output(OutputSpecs.size()) {
104a219a8a8SMircea Trofin   // INFO and DEBUG messages could be numerous and not particularly interesting
105a219a8a8SMircea Trofin   tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_WARNING);
1065ce4c9aaSMircea Trofin   // FIXME: make ErrorReporter a member (may also need subclassing
1075ce4c9aaSMircea Trofin   // StatefulErrorReporter) to easily get the latest error status, for
1085ce4c9aaSMircea Trofin   // debugging.
1095ce4c9aaSMircea Trofin   tflite::StderrReporter ErrorReporter;
1105ce4c9aaSMircea Trofin   SmallVector<char, 128> TFLitePathBuff;
1115ce4c9aaSMircea Trofin   llvm::sys::path::append(TFLitePathBuff, SavedModelPath, "model.tflite");
1125ce4c9aaSMircea Trofin   StringRef TFLitePath(TFLitePathBuff.data(), TFLitePathBuff.size());
1135ce4c9aaSMircea Trofin   Model = tflite::FlatBufferModel::BuildFromFile(TFLitePath.str().c_str(),
1145ce4c9aaSMircea Trofin                                                  &ErrorReporter);
1155ce4c9aaSMircea Trofin   if (!Model) {
1165ce4c9aaSMircea Trofin     invalidate();
1175ce4c9aaSMircea Trofin     return;
1185ce4c9aaSMircea Trofin   }
1195ce4c9aaSMircea Trofin 
1205ce4c9aaSMircea Trofin   tflite::ops::builtin::BuiltinOpResolver Resolver;
1215ce4c9aaSMircea Trofin   tflite::InterpreterBuilder Builder(*Model, Resolver);
1225ce4c9aaSMircea Trofin   Builder(&Interpreter);
1235ce4c9aaSMircea Trofin 
12417095dfeSJacob Hegna   if (!Interpreter) {
12517095dfeSJacob Hegna     invalidate();
12617095dfeSJacob Hegna     return;
12717095dfeSJacob Hegna   }
12817095dfeSJacob Hegna 
1299d93a98fSJacob Hegna   // We assume the input buffers are valid for the lifetime of the interpreter.
1309d93a98fSJacob Hegna   // By default, tflite allocates memory in an arena and will periodically take
1319d93a98fSJacob Hegna   // away memory and reallocate it in a different location after evaluations in
1329d93a98fSJacob Hegna   // order to improve utilization of the buffers owned in the arena. So, we
1339d93a98fSJacob Hegna   // explicitly mark our input buffers as persistent to avoid this behavior.
1349d93a98fSJacob Hegna   for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
1359d93a98fSJacob Hegna     Interpreter->tensor(I)->allocation_type =
1369d93a98fSJacob Hegna         TfLiteAllocationType::kTfLiteArenaRwPersistent;
1379d93a98fSJacob Hegna 
13817095dfeSJacob Hegna   if (Interpreter->AllocateTensors() != TfLiteStatus::kTfLiteOk) {
1395ce4c9aaSMircea Trofin     invalidate();
1405ce4c9aaSMircea Trofin     return;
1415ce4c9aaSMircea Trofin   }
1425ce4c9aaSMircea Trofin   // Known inputs and outputs
1435ce4c9aaSMircea Trofin   StringMap<int> InputsMap;
1445ce4c9aaSMircea Trofin   StringMap<int> OutputsMap;
1455ce4c9aaSMircea Trofin   for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
1465ce4c9aaSMircea Trofin     InputsMap[Interpreter->GetInputName(I)] = I;
1475ce4c9aaSMircea Trofin   for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
1485ce4c9aaSMircea Trofin     OutputsMap[Interpreter->GetOutputName(I)] = I;
1495ce4c9aaSMircea Trofin 
150ec83c7e3SAiden Grossman   size_t NumberFeaturesPassed = 0;
1515ce4c9aaSMircea Trofin   for (size_t I = 0; I < InputSpecs.size(); ++I) {
1525ce4c9aaSMircea Trofin     auto &InputSpec = InputSpecs[I];
1535ce4c9aaSMircea Trofin     auto MapI = InputsMap.find(InputSpec.name() + ":" +
1545ce4c9aaSMircea Trofin                                std::to_string(InputSpec.port()));
1555ce4c9aaSMircea Trofin     if (MapI == InputsMap.end()) {
1565ce4c9aaSMircea Trofin       Input[I] = nullptr;
1575ce4c9aaSMircea Trofin       continue;
1585ce4c9aaSMircea Trofin     }
1595ce4c9aaSMircea Trofin     Input[I] = Interpreter->tensor(MapI->second);
1605ce4c9aaSMircea Trofin     if (!checkReportAndInvalidate(Input[I], InputSpec))
1615ce4c9aaSMircea Trofin       return;
1625ce4c9aaSMircea Trofin     std::memset(Input[I]->data.data, 0,
1635ce4c9aaSMircea Trofin                 InputSpecs[I].getTotalTensorBufferSize());
164ec83c7e3SAiden Grossman     ++NumberFeaturesPassed;
165ec83c7e3SAiden Grossman   }
166ec83c7e3SAiden Grossman 
167ec83c7e3SAiden Grossman   if (NumberFeaturesPassed < Interpreter->inputs().size()) {
168ec83c7e3SAiden Grossman     // we haven't passed all the required features to the model, throw an error.
169ec83c7e3SAiden Grossman     errs() << "Required feature(s) have not been passed to the ML model";
170ec83c7e3SAiden Grossman     invalidate();
171ec83c7e3SAiden Grossman     return;
1725ce4c9aaSMircea Trofin   }
1735ce4c9aaSMircea Trofin 
1741ee3bb17SMircea Trofin   for (size_t I = 0; I < OutputSpecs.size(); ++I) {
1751ee3bb17SMircea Trofin     const auto &OutputSpec = OutputSpecs[I];
1765ce4c9aaSMircea Trofin     Output[I] = Interpreter->output_tensor(
1775ce4c9aaSMircea Trofin         OutputsMap[OutputSpec.name() + ":" +
1785ce4c9aaSMircea Trofin                    std::to_string(OutputSpec.port())]);
1795ce4c9aaSMircea Trofin     if (!checkReportAndInvalidate(Output[I], OutputSpec))
1805ce4c9aaSMircea Trofin       return;
1815ce4c9aaSMircea Trofin   }
1825ce4c9aaSMircea Trofin }
1835ce4c9aaSMircea Trofin 
TFModelEvaluator(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const char * Tags)1845ce4c9aaSMircea Trofin TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
1855ce4c9aaSMircea Trofin                                    const std::vector<TensorSpec> &InputSpecs,
1865ce4c9aaSMircea Trofin                                    const std::vector<TensorSpec> &OutputSpecs,
1875ce4c9aaSMircea Trofin                                    const char *Tags)
1881ee3bb17SMircea Trofin     : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, OutputSpecs,
1891ee3bb17SMircea Trofin                                     Tags)) {
1901ee3bb17SMircea Trofin   if (!Impl->isValid())
1911ee3bb17SMircea Trofin     Impl.reset();
1921ee3bb17SMircea Trofin }
1935ce4c9aaSMircea Trofin 
~TFModelEvaluatorImpl()1945ce4c9aaSMircea Trofin TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {}
1955ce4c9aaSMircea Trofin 
checkReportAndInvalidate(const TfLiteTensor * Tensor,const TensorSpec & Spec)1965ce4c9aaSMircea Trofin bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TfLiteTensor *Tensor,
1975ce4c9aaSMircea Trofin                                                     const TensorSpec &Spec) {
1985ce4c9aaSMircea Trofin   if (!Tensor) {
1995ce4c9aaSMircea Trofin     errs() << "Could not find TF_Output named: " + Spec.name();
2005ce4c9aaSMircea Trofin     IsValid = false;
2015ce4c9aaSMircea Trofin   }
2025ce4c9aaSMircea Trofin   if (Spec.getTotalTensorBufferSize() != Tensor->bytes)
2035ce4c9aaSMircea Trofin     IsValid = false;
2045ce4c9aaSMircea Trofin 
2055ce4c9aaSMircea Trofin   // If the total sizes match, there could still be a mismatch in the shape.
2065ce4c9aaSMircea Trofin   // We ignore that for now.
2075ce4c9aaSMircea Trofin 
2085ce4c9aaSMircea Trofin   return IsValid;
2095ce4c9aaSMircea Trofin }
2105ce4c9aaSMircea Trofin 
evaluate()211d4b6fcb3SFangrui Song std::optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
2125ce4c9aaSMircea Trofin   if (!isValid())
2139c444f70SKazu Hirata     return std::nullopt;
2145ce4c9aaSMircea Trofin   return EvaluationResult(Impl->evaluate());
2155ce4c9aaSMircea Trofin }
2165ce4c9aaSMircea Trofin 
getUntypedInput(size_t Index)2175ce4c9aaSMircea Trofin void *TFModelEvaluator::getUntypedInput(size_t Index) {
2185ce4c9aaSMircea Trofin   TfLiteTensor *T = Impl->getInput()[Index];
2195ce4c9aaSMircea Trofin   if (!T)
2205ce4c9aaSMircea Trofin     return nullptr;
2215ce4c9aaSMircea Trofin   return T->data.data;
2225ce4c9aaSMircea Trofin }
2235ce4c9aaSMircea Trofin 
EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl)2245ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(
2255ce4c9aaSMircea Trofin     std::unique_ptr<EvaluationResultImpl> Impl)
2265ce4c9aaSMircea Trofin     : Impl(std::move(Impl)) {}
2275ce4c9aaSMircea Trofin 
EvaluationResult(EvaluationResult && Other)2285ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
2295ce4c9aaSMircea Trofin     : Impl(std::move(Other.Impl)) {}
2305ce4c9aaSMircea Trofin 
2315ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult &
operator =(EvaluationResult && Other)2325ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
2335ce4c9aaSMircea Trofin   Impl = std::move(Other.Impl);
2345ce4c9aaSMircea Trofin   return *this;
2355ce4c9aaSMircea Trofin }
2365ce4c9aaSMircea Trofin 
getUntypedTensorValue(size_t Index)2375ce4c9aaSMircea Trofin void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
2385ce4c9aaSMircea Trofin   return Impl->getOutput(Index)->data.data;
2395ce4c9aaSMircea Trofin }
2405ce4c9aaSMircea Trofin 
2415ce4c9aaSMircea Trofin const void *
getUntypedTensorValue(size_t Index) const2425ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
2435ce4c9aaSMircea Trofin   return Impl->getOutput(Index)->data.data;
2445ce4c9aaSMircea Trofin }
2455ce4c9aaSMircea Trofin 
~EvaluationResult()2465ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
~TFModelEvaluator()2475ce4c9aaSMircea Trofin TFModelEvaluator::~TFModelEvaluator() {}
2485ce4c9aaSMircea Trofin 
249edc83a15SKazu Hirata #endif // defined(LLVM_HAVE_TFLITE)
250