1*28bb2193SMircea Trofin //===- TFUtils.cpp - TFLite-based evaluation utilities --------------------===//
25ce4c9aaSMircea Trofin //
35ce4c9aaSMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45ce4c9aaSMircea Trofin // See https://llvm.org/LICENSE.txt for license information.
55ce4c9aaSMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65ce4c9aaSMircea Trofin //
75ce4c9aaSMircea Trofin //===----------------------------------------------------------------------===//
85ce4c9aaSMircea Trofin //
9*28bb2193SMircea Trofin // This file implements utilities for interfacing with TFLite.
105ce4c9aaSMircea Trofin //
115ce4c9aaSMircea Trofin //===----------------------------------------------------------------------===//
125ce4c9aaSMircea Trofin #include "llvm/Config/config.h"
135ce4c9aaSMircea Trofin #if defined(LLVM_HAVE_TFLITE)
145ce4c9aaSMircea Trofin
155ce4c9aaSMircea Trofin #include "llvm/ADT/Twine.h"
165ce4c9aaSMircea Trofin #include "llvm/Analysis/Utils/TFUtils.h"
175ce4c9aaSMircea Trofin #include "llvm/Support/Base64.h"
185ce4c9aaSMircea Trofin #include "llvm/Support/CommandLine.h"
195ce4c9aaSMircea Trofin #include "llvm/Support/Debug.h"
205ce4c9aaSMircea Trofin #include "llvm/Support/JSON.h"
215ce4c9aaSMircea Trofin #include "llvm/Support/MemoryBuffer.h"
225ce4c9aaSMircea Trofin #include "llvm/Support/Path.h"
235ce4c9aaSMircea Trofin #include "llvm/Support/raw_ostream.h"
245ce4c9aaSMircea Trofin
255ce4c9aaSMircea Trofin #include "tensorflow/lite/interpreter.h"
265ce4c9aaSMircea Trofin #include "tensorflow/lite/kernels/register.h"
275ce4c9aaSMircea Trofin #include "tensorflow/lite/model.h"
285ce4c9aaSMircea Trofin #include "tensorflow/lite/model_builder.h"
295ce4c9aaSMircea Trofin #include "tensorflow/lite/op_resolver.h"
30a219a8a8SMircea Trofin #include "tensorflow/lite/logger.h"
315ce4c9aaSMircea Trofin
325ce4c9aaSMircea Trofin #include <cassert>
335ce4c9aaSMircea Trofin #include <numeric>
349c444f70SKazu Hirata #include <optional>
355ce4c9aaSMircea Trofin
365ce4c9aaSMircea Trofin using namespace llvm;
375ce4c9aaSMircea Trofin
385ce4c9aaSMircea Trofin namespace llvm {
395ce4c9aaSMircea Trofin class EvaluationResultImpl {
405ce4c9aaSMircea Trofin public:
EvaluationResultImpl(const std::vector<const TfLiteTensor * > & Outputs)415ce4c9aaSMircea Trofin EvaluationResultImpl(const std::vector<const TfLiteTensor *> &Outputs)
425ce4c9aaSMircea Trofin : Outputs(Outputs){};
435ce4c9aaSMircea Trofin
getOutput(size_t I)445ce4c9aaSMircea Trofin const TfLiteTensor *getOutput(size_t I) { return Outputs[I]; }
455ce4c9aaSMircea Trofin
465ce4c9aaSMircea Trofin EvaluationResultImpl(const EvaluationResultImpl &) = delete;
475ce4c9aaSMircea Trofin EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
485ce4c9aaSMircea Trofin
495ce4c9aaSMircea Trofin private:
505ce4c9aaSMircea Trofin const std::vector<const TfLiteTensor *> Outputs;
515ce4c9aaSMircea Trofin };
525ce4c9aaSMircea Trofin
535ce4c9aaSMircea Trofin class TFModelEvaluatorImpl {
545ce4c9aaSMircea Trofin public:
555ce4c9aaSMircea Trofin TFModelEvaluatorImpl(StringRef SavedModelPath,
565ce4c9aaSMircea Trofin const std::vector<TensorSpec> &InputSpecs,
571ee3bb17SMircea Trofin const std::vector<TensorSpec> &OutputSpecs,
581ee3bb17SMircea Trofin const char *Tags);
595ce4c9aaSMircea Trofin
isValid() const605ce4c9aaSMircea Trofin bool isValid() const { return IsValid; }
outputSize() const615ce4c9aaSMircea Trofin size_t outputSize() const { return Output.size(); }
625ce4c9aaSMircea Trofin
evaluate()635ce4c9aaSMircea Trofin std::unique_ptr<EvaluationResultImpl> evaluate() {
645ce4c9aaSMircea Trofin Interpreter->Invoke();
655ce4c9aaSMircea Trofin return std::make_unique<EvaluationResultImpl>(Output);
665ce4c9aaSMircea Trofin }
675ce4c9aaSMircea Trofin
getInput() const685ce4c9aaSMircea Trofin const std::vector<TfLiteTensor *> &getInput() const { return Input; }
695ce4c9aaSMircea Trofin
705ce4c9aaSMircea Trofin ~TFModelEvaluatorImpl();
715ce4c9aaSMircea Trofin
725ce4c9aaSMircea Trofin private:
735ce4c9aaSMircea Trofin std::unique_ptr<tflite::FlatBufferModel> Model;
745ce4c9aaSMircea Trofin
755ce4c9aaSMircea Trofin /// The objects necessary for carrying out an evaluation of the SavedModel.
765ce4c9aaSMircea Trofin /// They are expensive to set up, and we maintain them accross all the
775ce4c9aaSMircea Trofin /// evaluations of the model.
785ce4c9aaSMircea Trofin std::unique_ptr<tflite::Interpreter> Interpreter;
795ce4c9aaSMircea Trofin
805ce4c9aaSMircea Trofin /// The input tensors. We set up the tensors once and just mutate theirs
815ce4c9aaSMircea Trofin /// scalars before each evaluation. The input tensors keep their value after
825ce4c9aaSMircea Trofin /// an evaluation.
835ce4c9aaSMircea Trofin std::vector<TfLiteTensor *> Input;
845ce4c9aaSMircea Trofin
855ce4c9aaSMircea Trofin /// The output nodes.
865ce4c9aaSMircea Trofin std::vector<const TfLiteTensor *> Output;
875ce4c9aaSMircea Trofin
invalidate()885ce4c9aaSMircea Trofin void invalidate() { IsValid = false; }
895ce4c9aaSMircea Trofin
905ce4c9aaSMircea Trofin bool IsValid = true;
915ce4c9aaSMircea Trofin
925ce4c9aaSMircea Trofin /// Reusable utility for ensuring we can bind the requested Name to a node in
935ce4c9aaSMircea Trofin /// the SavedModel Graph.
945ce4c9aaSMircea Trofin bool checkReportAndInvalidate(const TfLiteTensor *Tensor,
955ce4c9aaSMircea Trofin const TensorSpec &Spec);
965ce4c9aaSMircea Trofin };
975ce4c9aaSMircea Trofin
985ce4c9aaSMircea Trofin } // namespace llvm
995ce4c9aaSMircea Trofin
TFModelEvaluatorImpl(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const char * Tags="serve")1005ce4c9aaSMircea Trofin TFModelEvaluatorImpl::TFModelEvaluatorImpl(
1015ce4c9aaSMircea Trofin StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
1021ee3bb17SMircea Trofin const std::vector<TensorSpec> &OutputSpecs, const char *Tags = "serve")
1031ee3bb17SMircea Trofin : Input(InputSpecs.size()), Output(OutputSpecs.size()) {
104a219a8a8SMircea Trofin // INFO and DEBUG messages could be numerous and not particularly interesting
105a219a8a8SMircea Trofin tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_WARNING);
1065ce4c9aaSMircea Trofin // FIXME: make ErrorReporter a member (may also need subclassing
1075ce4c9aaSMircea Trofin // StatefulErrorReporter) to easily get the latest error status, for
1085ce4c9aaSMircea Trofin // debugging.
1095ce4c9aaSMircea Trofin tflite::StderrReporter ErrorReporter;
1105ce4c9aaSMircea Trofin SmallVector<char, 128> TFLitePathBuff;
1115ce4c9aaSMircea Trofin llvm::sys::path::append(TFLitePathBuff, SavedModelPath, "model.tflite");
1125ce4c9aaSMircea Trofin StringRef TFLitePath(TFLitePathBuff.data(), TFLitePathBuff.size());
1135ce4c9aaSMircea Trofin Model = tflite::FlatBufferModel::BuildFromFile(TFLitePath.str().c_str(),
1145ce4c9aaSMircea Trofin &ErrorReporter);
1155ce4c9aaSMircea Trofin if (!Model) {
1165ce4c9aaSMircea Trofin invalidate();
1175ce4c9aaSMircea Trofin return;
1185ce4c9aaSMircea Trofin }
1195ce4c9aaSMircea Trofin
1205ce4c9aaSMircea Trofin tflite::ops::builtin::BuiltinOpResolver Resolver;
1215ce4c9aaSMircea Trofin tflite::InterpreterBuilder Builder(*Model, Resolver);
1225ce4c9aaSMircea Trofin Builder(&Interpreter);
1235ce4c9aaSMircea Trofin
12417095dfeSJacob Hegna if (!Interpreter) {
12517095dfeSJacob Hegna invalidate();
12617095dfeSJacob Hegna return;
12717095dfeSJacob Hegna }
12817095dfeSJacob Hegna
1299d93a98fSJacob Hegna // We assume the input buffers are valid for the lifetime of the interpreter.
1309d93a98fSJacob Hegna // By default, tflite allocates memory in an arena and will periodically take
1319d93a98fSJacob Hegna // away memory and reallocate it in a different location after evaluations in
1329d93a98fSJacob Hegna // order to improve utilization of the buffers owned in the arena. So, we
1339d93a98fSJacob Hegna // explicitly mark our input buffers as persistent to avoid this behavior.
1349d93a98fSJacob Hegna for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
1359d93a98fSJacob Hegna Interpreter->tensor(I)->allocation_type =
1369d93a98fSJacob Hegna TfLiteAllocationType::kTfLiteArenaRwPersistent;
1379d93a98fSJacob Hegna
13817095dfeSJacob Hegna if (Interpreter->AllocateTensors() != TfLiteStatus::kTfLiteOk) {
1395ce4c9aaSMircea Trofin invalidate();
1405ce4c9aaSMircea Trofin return;
1415ce4c9aaSMircea Trofin }
1425ce4c9aaSMircea Trofin // Known inputs and outputs
1435ce4c9aaSMircea Trofin StringMap<int> InputsMap;
1445ce4c9aaSMircea Trofin StringMap<int> OutputsMap;
1455ce4c9aaSMircea Trofin for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
1465ce4c9aaSMircea Trofin InputsMap[Interpreter->GetInputName(I)] = I;
1475ce4c9aaSMircea Trofin for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
1485ce4c9aaSMircea Trofin OutputsMap[Interpreter->GetOutputName(I)] = I;
1495ce4c9aaSMircea Trofin
150ec83c7e3SAiden Grossman size_t NumberFeaturesPassed = 0;
1515ce4c9aaSMircea Trofin for (size_t I = 0; I < InputSpecs.size(); ++I) {
1525ce4c9aaSMircea Trofin auto &InputSpec = InputSpecs[I];
1535ce4c9aaSMircea Trofin auto MapI = InputsMap.find(InputSpec.name() + ":" +
1545ce4c9aaSMircea Trofin std::to_string(InputSpec.port()));
1555ce4c9aaSMircea Trofin if (MapI == InputsMap.end()) {
1565ce4c9aaSMircea Trofin Input[I] = nullptr;
1575ce4c9aaSMircea Trofin continue;
1585ce4c9aaSMircea Trofin }
1595ce4c9aaSMircea Trofin Input[I] = Interpreter->tensor(MapI->second);
1605ce4c9aaSMircea Trofin if (!checkReportAndInvalidate(Input[I], InputSpec))
1615ce4c9aaSMircea Trofin return;
1625ce4c9aaSMircea Trofin std::memset(Input[I]->data.data, 0,
1635ce4c9aaSMircea Trofin InputSpecs[I].getTotalTensorBufferSize());
164ec83c7e3SAiden Grossman ++NumberFeaturesPassed;
165ec83c7e3SAiden Grossman }
166ec83c7e3SAiden Grossman
167ec83c7e3SAiden Grossman if (NumberFeaturesPassed < Interpreter->inputs().size()) {
168ec83c7e3SAiden Grossman // we haven't passed all the required features to the model, throw an error.
169ec83c7e3SAiden Grossman errs() << "Required feature(s) have not been passed to the ML model";
170ec83c7e3SAiden Grossman invalidate();
171ec83c7e3SAiden Grossman return;
1725ce4c9aaSMircea Trofin }
1735ce4c9aaSMircea Trofin
1741ee3bb17SMircea Trofin for (size_t I = 0; I < OutputSpecs.size(); ++I) {
1751ee3bb17SMircea Trofin const auto &OutputSpec = OutputSpecs[I];
1765ce4c9aaSMircea Trofin Output[I] = Interpreter->output_tensor(
1775ce4c9aaSMircea Trofin OutputsMap[OutputSpec.name() + ":" +
1785ce4c9aaSMircea Trofin std::to_string(OutputSpec.port())]);
1795ce4c9aaSMircea Trofin if (!checkReportAndInvalidate(Output[I], OutputSpec))
1805ce4c9aaSMircea Trofin return;
1815ce4c9aaSMircea Trofin }
1825ce4c9aaSMircea Trofin }
1835ce4c9aaSMircea Trofin
TFModelEvaluator(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const char * Tags)1845ce4c9aaSMircea Trofin TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
1855ce4c9aaSMircea Trofin const std::vector<TensorSpec> &InputSpecs,
1865ce4c9aaSMircea Trofin const std::vector<TensorSpec> &OutputSpecs,
1875ce4c9aaSMircea Trofin const char *Tags)
1881ee3bb17SMircea Trofin : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, OutputSpecs,
1891ee3bb17SMircea Trofin Tags)) {
1901ee3bb17SMircea Trofin if (!Impl->isValid())
1911ee3bb17SMircea Trofin Impl.reset();
1921ee3bb17SMircea Trofin }
1935ce4c9aaSMircea Trofin
~TFModelEvaluatorImpl()1945ce4c9aaSMircea Trofin TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {}
1955ce4c9aaSMircea Trofin
checkReportAndInvalidate(const TfLiteTensor * Tensor,const TensorSpec & Spec)1965ce4c9aaSMircea Trofin bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TfLiteTensor *Tensor,
1975ce4c9aaSMircea Trofin const TensorSpec &Spec) {
1985ce4c9aaSMircea Trofin if (!Tensor) {
1995ce4c9aaSMircea Trofin errs() << "Could not find TF_Output named: " + Spec.name();
2005ce4c9aaSMircea Trofin IsValid = false;
2015ce4c9aaSMircea Trofin }
2025ce4c9aaSMircea Trofin if (Spec.getTotalTensorBufferSize() != Tensor->bytes)
2035ce4c9aaSMircea Trofin IsValid = false;
2045ce4c9aaSMircea Trofin
2055ce4c9aaSMircea Trofin // If the total sizes match, there could still be a mismatch in the shape.
2065ce4c9aaSMircea Trofin // We ignore that for now.
2075ce4c9aaSMircea Trofin
2085ce4c9aaSMircea Trofin return IsValid;
2095ce4c9aaSMircea Trofin }
2105ce4c9aaSMircea Trofin
evaluate()211d4b6fcb3SFangrui Song std::optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
2125ce4c9aaSMircea Trofin if (!isValid())
2139c444f70SKazu Hirata return std::nullopt;
2145ce4c9aaSMircea Trofin return EvaluationResult(Impl->evaluate());
2155ce4c9aaSMircea Trofin }
2165ce4c9aaSMircea Trofin
getUntypedInput(size_t Index)2175ce4c9aaSMircea Trofin void *TFModelEvaluator::getUntypedInput(size_t Index) {
2185ce4c9aaSMircea Trofin TfLiteTensor *T = Impl->getInput()[Index];
2195ce4c9aaSMircea Trofin if (!T)
2205ce4c9aaSMircea Trofin return nullptr;
2215ce4c9aaSMircea Trofin return T->data.data;
2225ce4c9aaSMircea Trofin }
2235ce4c9aaSMircea Trofin
EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl)2245ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(
2255ce4c9aaSMircea Trofin std::unique_ptr<EvaluationResultImpl> Impl)
2265ce4c9aaSMircea Trofin : Impl(std::move(Impl)) {}
2275ce4c9aaSMircea Trofin
EvaluationResult(EvaluationResult && Other)2285ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
2295ce4c9aaSMircea Trofin : Impl(std::move(Other.Impl)) {}
2305ce4c9aaSMircea Trofin
2315ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult &
operator =(EvaluationResult && Other)2325ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
2335ce4c9aaSMircea Trofin Impl = std::move(Other.Impl);
2345ce4c9aaSMircea Trofin return *this;
2355ce4c9aaSMircea Trofin }
2365ce4c9aaSMircea Trofin
getUntypedTensorValue(size_t Index)2375ce4c9aaSMircea Trofin void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
2385ce4c9aaSMircea Trofin return Impl->getOutput(Index)->data.data;
2395ce4c9aaSMircea Trofin }
2405ce4c9aaSMircea Trofin
2415ce4c9aaSMircea Trofin const void *
getUntypedTensorValue(size_t Index) const2425ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
2435ce4c9aaSMircea Trofin return Impl->getOutput(Index)->data.data;
2445ce4c9aaSMircea Trofin }
2455ce4c9aaSMircea Trofin
~EvaluationResult()2465ce4c9aaSMircea Trofin TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
~TFModelEvaluator()2475ce4c9aaSMircea Trofin TFModelEvaluator::~TFModelEvaluator() {}
2485ce4c9aaSMircea Trofin
249edc83a15SKazu Hirata #endif // defined(LLVM_HAVE_TFLITE)
250