1bdceefe9SMircea Trofin //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// 2bdceefe9SMircea Trofin // 3bdceefe9SMircea Trofin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4bdceefe9SMircea Trofin // See https://llvm.org/LICENSE.txt for license information. 5bdceefe9SMircea Trofin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6bdceefe9SMircea Trofin // 7bdceefe9SMircea Trofin //===----------------------------------------------------------------------===// 8bdceefe9SMircea Trofin // 9bdceefe9SMircea Trofin 10bdceefe9SMircea Trofin #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H 11bdceefe9SMircea Trofin #define LLVM_ANALYSIS_MLMODELRUNNER_H 12bdceefe9SMircea Trofin 13c35ad9eeSMircea Trofin #include "llvm/Analysis/TensorSpec.h" 14bdceefe9SMircea Trofin #include "llvm/IR/PassManager.h" 15bdceefe9SMircea Trofin 16bdceefe9SMircea Trofin namespace llvm { 17c6f0940dSBill Wendling class LLVMContext; 18bdceefe9SMircea Trofin 19bdceefe9SMircea Trofin /// MLModelRunner interface: abstraction of a mechanism for evaluating a 2028bb2193SMircea Trofin /// ML model. More abstractly, evaluating a function that has as tensors as 2128bb2193SMircea Trofin /// arguments, described via TensorSpecs, and returns a tensor. Currently, the 2228bb2193SMircea Trofin /// latter is assumed to be a scalar, in absence of more elaborate scenarios. 23db5aceb9SMircea Trofin /// NOTE: feature indices are expected to be consistent all accross 24db5aceb9SMircea Trofin /// MLModelRunners (pertaining to the same model), and also Loggers (see 25db5aceb9SMircea Trofin /// TFUtils.h) 26bdceefe9SMircea Trofin class MLModelRunner { 27bdceefe9SMircea Trofin public: 28bdceefe9SMircea Trofin // Disallows copy and assign. 29bdceefe9SMircea Trofin MLModelRunner(const MLModelRunner &) = delete; 30bdceefe9SMircea Trofin MLModelRunner &operator=(const MLModelRunner &) = delete; 31bdceefe9SMircea Trofin virtual ~MLModelRunner() = default; 32bdceefe9SMircea Trofin 33059e0347SMircea Trofin template <typename T> T evaluate() { 34059e0347SMircea Trofin return *reinterpret_cast<T *>(evaluateUntyped()); 35059e0347SMircea Trofin } 36059e0347SMircea Trofin 37059e0347SMircea Trofin template <typename T, typename I> T *getTensor(I FeatureID) { 38059e0347SMircea Trofin return reinterpret_cast<T *>( 39059e0347SMircea Trofin getTensorUntyped(static_cast<size_t>(FeatureID))); 40059e0347SMircea Trofin } 41059e0347SMircea Trofin 42059e0347SMircea Trofin template <typename T, typename I> const T *getTensor(I FeatureID) const { 43059e0347SMircea Trofin return reinterpret_cast<const T *>( 44059e0347SMircea Trofin getTensorUntyped(static_cast<size_t>(FeatureID))); 45059e0347SMircea Trofin } 46bdceefe9SMircea Trofin 47c35ad9eeSMircea Trofin void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } 4868ac7b17SMircea Trofin const void *getTensorUntyped(size_t Index) const { 4968ac7b17SMircea Trofin return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); 5068ac7b17SMircea Trofin } 5168ac7b17SMircea Trofin 525b8dc7c8SMircea Trofin enum class Kind : int { Unknown, Release, Development, NoOp, Interactive }; 53a120fdd3SMircea Trofin Kind getKind() const { return Type; } 545fd51fcbSMircea Trofin virtual void switchContext(StringRef Name) {} 55a120fdd3SMircea Trofin 56bdceefe9SMircea Trofin protected: 57*f32e5bdcSMircea Trofin MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NumInputs) 58*f32e5bdcSMircea Trofin : Ctx(Ctx), Type(Type), InputBuffers(NumInputs) { 59a120fdd3SMircea Trofin assert(Type != Kind::Unknown); 60a120fdd3SMircea Trofin } 61059e0347SMircea Trofin virtual void *evaluateUntyped() = 0; 62bdceefe9SMircea Trofin 63c35ad9eeSMircea Trofin void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, 64c35ad9eeSMircea Trofin void *Buffer) { 65c35ad9eeSMircea Trofin if (!Buffer) { 66c35ad9eeSMircea Trofin OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); 67c35ad9eeSMircea Trofin Buffer = OwnedBuffers.back().data(); 68c35ad9eeSMircea Trofin } 69c35ad9eeSMircea Trofin InputBuffers[Index] = Buffer; 70c35ad9eeSMircea Trofin } 71c35ad9eeSMircea Trofin 72bdceefe9SMircea Trofin LLVMContext &Ctx; 73a120fdd3SMircea Trofin const Kind Type; 74c35ad9eeSMircea Trofin 75c35ad9eeSMircea Trofin private: 76c35ad9eeSMircea Trofin std::vector<void *> InputBuffers; 77c35ad9eeSMircea Trofin std::vector<std::vector<char *>> OwnedBuffers; 78bdceefe9SMircea Trofin }; 79bdceefe9SMircea Trofin } // namespace llvm 80bdceefe9SMircea Trofin 81bdceefe9SMircea Trofin #endif // LLVM_ANALYSIS_MLMODELRUNNER_H 82