//===- ReleaseModeModelRunner.h - Fast, precompiled model runner ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a model runner wrapping an AOT compiled ML model.
// Only inference is supported.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
#define LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H

#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MD5.h"

#include <memory>

namespace llvm {

/// Construction-time options for a ReleaseModeModelRunner. Setters return
/// *this so options can be chained fluently, e.g.
///   EmbeddedModelRunnerOptions().setFeedPrefix("x_").setModelSelector("m")
struct EmbeddedModelRunnerOptions {
  /// Feed and Fetch feature prefixes - i.e. a feature named "foo" will be
  /// looked up as {FeedPrefix}_foo; and the output named "bar" will be looked
  /// up as {FetchPrefix}_bar
  StringRef FeedPrefix = "feed_";
  StringRef FetchPrefix = "fetch_";

  /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
  /// to use. "" is allowed if the model doesn't support sub-models.
  StringRef ModelSelector = "";

  EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) {
    FeedPrefix = Value;
    return *this;
  }
  EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) {
    FetchPrefix = Value;
    return *this;
  }
  EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) {
    ModelSelector = Value;
    return *this;
  }
};

/// ReleaseModeModelRunner - production mode implementation of the
/// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
/// TGen is the AOT-generated model class; it must expose LookupArgIndex,
/// LookupResultIndex, arg_data, result_data, and Run (see NoopSavedModelImpl
/// below for the expected interface).
template <class TGen>
class ReleaseModeModelRunner final : public MLModelRunner {
public:
  /// InputSpec's type should be an indexed collection of TensorSpec, like
  /// std::array or std::vector, that has a size() method. One extra tensor
  /// slot (at index InputSpec.size()) is reserved for the model selector.
  ///
  /// \param Ctx used for diagnostics; errors are emitted through it rather
  ///   than aborting, so a custom diagnostic handler may keep running.
  /// \param InputSpec the specs of the model's input features, in feed order.
  /// \param DecisionName name of the output to fetch (after FetchPrefix).
  /// \param Options feed/fetch prefixes and the optional sub-model selector.
  template <class FType>
  ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
                         StringRef DecisionName,
                         const EmbeddedModelRunnerOptions &Options = {})
      : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
        CompiledModel(std::make_unique<TGen>()) {
    assert(CompiledModel && "The CompiledModel should be valid");
    // Set up the model_selector past all the InputSpecs in all cases.
    //   - if the model doesn't have such a feature, but the user requested it,
    //     we report error. Same if the model supports it but the user didn't
    //     specify it
    //   - finally, we compute the MD5 hash of the user input and set the value
    //     of the model selector to {high, low}
    bool InputIsPresent = true;
    populateTensor(InputSpec.size(),
                   TensorSpec::createSpec<uint64_t>("model_selector", {2}),
                   Options.FeedPrefix, InputIsPresent);

    // If we hit the "report an error" cases outlined above, continue with the
    // set up in case there's some custom diagnostics handler installed and it
    // doesn't promptly exit.
    if (Options.ModelSelector.empty() && InputIsPresent)
      Ctx.emitError(
          "A model selector was not specified but the underlying model "
          "requires selecting one because it exposes a model_selector input");
    uint64_t High = 0;
    uint64_t Low = 0;
    if (!Options.ModelSelector.empty()) {
      if (!InputIsPresent)
        Ctx.emitError("A model selector was specified but the underlying model "
                      "does not expose a model_selector input");
      // The model identifies sub-models by the MD5 of the selector string,
      // split into its high and low 64-bit halves.
      const auto Hash = MD5::hash(arrayRefFromStringRef(Options.ModelSelector));
      High = Hash.high();
      Low = Hash.low();
    }
    getTensor<uint64_t>(InputSpec.size())[0] = High;
    getTensor<uint64_t>(InputSpec.size())[1] = Low;
    // At this point, the model selector is set up. If the user didn't provide
    // one, but the model has a model_selector, it'll be set to (0, 0) which
    // the composite model should treat as error as part of its implementation
    // (but that should only matter if there is a custom handler that doesn't
    // exit on error)
    // Bind each user-declared feature; InputIsPresent is reused as scratch
    // here - per-feature absence is handled by falling back to a scratch
    // buffer inside populateTensor, not reported as an error.
    for (size_t I = 0; I < InputSpec.size(); ++I)
      populateTensor(I, InputSpec[I], Options.FeedPrefix, InputIsPresent);

    ResultIndex = CompiledModel->LookupResultIndex(Options.FetchPrefix.str() +
                                                   DecisionName.str());
    assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
  }

  virtual ~ReleaseModeModelRunner() = default;

  /// Support for LLVM-style RTTI (isa<>/dyn_cast<>).
  static bool classof(const MLModelRunner *R) {
    return R->getKind() == MLModelRunner::Kind::Release;
  }

private:
  // fetch the model-provided buffer for the given Spec, or let MLModelRunner
  // create a scratch buffer. Indicate back to the caller if the model had that
  // input in the first place.
  void populateTensor(size_t Pos, const TensorSpec &Spec, StringRef Prefix,
                      bool &InputIsPresent) {
    const int Index =
        CompiledModel->LookupArgIndex((Prefix + Spec.name()).str());
    void *Buffer = nullptr;
    InputIsPresent = Index >= 0;
    if (InputIsPresent)
      Buffer = CompiledModel->arg_data(Index);
    // A null Buffer makes the base class allocate a scratch buffer instead.
    setUpBufferForTensor(Pos, Spec, Buffer);
  }

  // Run one inference over the currently-populated input buffers and return
  // a pointer to the model's output buffer for DecisionName.
  void *evaluateUntyped() override {
    CompiledModel->Run();
    return CompiledModel->result_data(ResultIndex);
  }

  // Index of the fetched output within the compiled model; -1 until resolved
  // in the constructor.
  int32_t ResultIndex = -1;
  std::unique_ptr<TGen> CompiledModel;
};

/// A mock class satisfying the interface expected by ReleaseModeModelRunner for
/// its `TGen` parameter. Useful to avoid conditional compilation complexity, as
/// a compile-time replacement for a real AOT-ed model. Every member traps at
/// runtime: this stub must never actually be evaluated.
class NoopSavedModelImpl final {
#define NOOP_MODEL_ERRMSG                                                     \
  "The mock AOT-ed saved model is a compile-time stub and should not be "     \
  "called."

public:
  NoopSavedModelImpl() = default;
  int LookupArgIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  int LookupResultIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  void Run() { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  void *result_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
  void *arg_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
#undef NOOP_MODEL_ERRMSG
};

/// Compile-time check for whether a TGen type is a real AOT-ed model (true)
/// or the NoopSavedModelImpl placeholder (false, via the specialization).
template <class T> bool isEmbeddedModelEvaluatorValid() { return true; }

template <> inline bool isEmbeddedModelEvaluatorValid<NoopSavedModelImpl>() {
  return false;
}
} // namespace llvm

#endif // LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H