xref: /llvm-project/llvm/include/llvm/Analysis/MLModelRunner.h (revision f32e5bdcefcff80f4296f8f4abedc37dcda36d53)
1bdceefe9SMircea Trofin //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
2bdceefe9SMircea Trofin //
3bdceefe9SMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdceefe9SMircea Trofin // See https://llvm.org/LICENSE.txt for license information.
5bdceefe9SMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdceefe9SMircea Trofin //
7bdceefe9SMircea Trofin //===----------------------------------------------------------------------===//
8bdceefe9SMircea Trofin //
9bdceefe9SMircea Trofin 
10bdceefe9SMircea Trofin #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
11bdceefe9SMircea Trofin #define LLVM_ANALYSIS_MLMODELRUNNER_H
12bdceefe9SMircea Trofin 
13c35ad9eeSMircea Trofin #include "llvm/Analysis/TensorSpec.h"
14bdceefe9SMircea Trofin #include "llvm/IR/PassManager.h"
15bdceefe9SMircea Trofin 
16bdceefe9SMircea Trofin namespace llvm {
17c6f0940dSBill Wendling class LLVMContext;
18bdceefe9SMircea Trofin 
19bdceefe9SMircea Trofin /// MLModelRunner interface: abstraction of a mechanism for evaluating a
2028bb2193SMircea Trofin /// ML model. More abstractly, evaluating a function that has as tensors as
2128bb2193SMircea Trofin /// arguments, described via TensorSpecs, and returns a tensor. Currently, the
2228bb2193SMircea Trofin /// latter is assumed to be a scalar, in absence of more elaborate scenarios.
23db5aceb9SMircea Trofin /// NOTE: feature indices are expected to be consistent all accross
24db5aceb9SMircea Trofin /// MLModelRunners (pertaining to the same model), and also Loggers (see
25db5aceb9SMircea Trofin /// TFUtils.h)
26bdceefe9SMircea Trofin class MLModelRunner {
27bdceefe9SMircea Trofin public:
28bdceefe9SMircea Trofin   // Disallows copy and assign.
29bdceefe9SMircea Trofin   MLModelRunner(const MLModelRunner &) = delete;
30bdceefe9SMircea Trofin   MLModelRunner &operator=(const MLModelRunner &) = delete;
31bdceefe9SMircea Trofin   virtual ~MLModelRunner() = default;
32bdceefe9SMircea Trofin 
33059e0347SMircea Trofin   template <typename T> T evaluate() {
34059e0347SMircea Trofin     return *reinterpret_cast<T *>(evaluateUntyped());
35059e0347SMircea Trofin   }
36059e0347SMircea Trofin 
37059e0347SMircea Trofin   template <typename T, typename I> T *getTensor(I FeatureID) {
38059e0347SMircea Trofin     return reinterpret_cast<T *>(
39059e0347SMircea Trofin         getTensorUntyped(static_cast<size_t>(FeatureID)));
40059e0347SMircea Trofin   }
41059e0347SMircea Trofin 
42059e0347SMircea Trofin   template <typename T, typename I> const T *getTensor(I FeatureID) const {
43059e0347SMircea Trofin     return reinterpret_cast<const T *>(
44059e0347SMircea Trofin         getTensorUntyped(static_cast<size_t>(FeatureID)));
45059e0347SMircea Trofin   }
46bdceefe9SMircea Trofin 
47c35ad9eeSMircea Trofin   void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; }
4868ac7b17SMircea Trofin   const void *getTensorUntyped(size_t Index) const {
4968ac7b17SMircea Trofin     return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index);
5068ac7b17SMircea Trofin   }
5168ac7b17SMircea Trofin 
525b8dc7c8SMircea Trofin   enum class Kind : int { Unknown, Release, Development, NoOp, Interactive };
53a120fdd3SMircea Trofin   Kind getKind() const { return Type; }
545fd51fcbSMircea Trofin   virtual void switchContext(StringRef Name) {}
55a120fdd3SMircea Trofin 
56bdceefe9SMircea Trofin protected:
57*f32e5bdcSMircea Trofin   MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NumInputs)
58*f32e5bdcSMircea Trofin       : Ctx(Ctx), Type(Type), InputBuffers(NumInputs) {
59a120fdd3SMircea Trofin     assert(Type != Kind::Unknown);
60a120fdd3SMircea Trofin   }
61059e0347SMircea Trofin   virtual void *evaluateUntyped() = 0;
62bdceefe9SMircea Trofin 
63c35ad9eeSMircea Trofin   void setUpBufferForTensor(size_t Index, const TensorSpec &Spec,
64c35ad9eeSMircea Trofin                             void *Buffer) {
65c35ad9eeSMircea Trofin     if (!Buffer) {
66c35ad9eeSMircea Trofin       OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize());
67c35ad9eeSMircea Trofin       Buffer = OwnedBuffers.back().data();
68c35ad9eeSMircea Trofin     }
69c35ad9eeSMircea Trofin     InputBuffers[Index] = Buffer;
70c35ad9eeSMircea Trofin   }
71c35ad9eeSMircea Trofin 
72bdceefe9SMircea Trofin   LLVMContext &Ctx;
73a120fdd3SMircea Trofin   const Kind Type;
74c35ad9eeSMircea Trofin 
75c35ad9eeSMircea Trofin private:
76c35ad9eeSMircea Trofin   std::vector<void *> InputBuffers;
77c35ad9eeSMircea Trofin   std::vector<std::vector<char *>> OwnedBuffers;
78bdceefe9SMircea Trofin };
79bdceefe9SMircea Trofin } // namespace llvm
80bdceefe9SMircea Trofin 
81bdceefe9SMircea Trofin #endif // LLVM_ANALYSIS_MLMODELRUNNER_H
82