xref: /llvm-project/llvm/lib/CodeGen/MLRegAllocPriorityAdvisor.cpp (revision 4010f894a1e880f88bda78a49a8bece5affaa848)
165b40f27SMatt Arsenault //===- MLRegAllocPriorityAdvisor.cpp - ML priority advisor-----------------===//
265b40f27SMatt Arsenault //
365b40f27SMatt Arsenault // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
465b40f27SMatt Arsenault // See https://llvm.org/LICENSE.txt for license information.
565b40f27SMatt Arsenault // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
665b40f27SMatt Arsenault //
765b40f27SMatt Arsenault //===----------------------------------------------------------------------===//
865b40f27SMatt Arsenault //
965b40f27SMatt Arsenault // Implementation of the ML priority advisor and reward injection pass
1065b40f27SMatt Arsenault //
1165b40f27SMatt Arsenault //===----------------------------------------------------------------------===//
1265b40f27SMatt Arsenault 
#include "AllocationOrder.h"
#include "RegAllocGreedy.h"
#include "RegAllocPriorityAdvisor.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InteractiveModelRunner.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/CodeGen/CalcSpillWeights.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/RegisterClassInfo.h"
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"

#include <limits>
3565b40f27SMatt Arsenault 
3665b40f27SMatt Arsenault #if defined(LLVM_HAVE_TFLITE)
3765b40f27SMatt Arsenault #include "llvm/Analysis/ModelUnderTrainingRunner.h"
3865b40f27SMatt Arsenault #include "llvm/Analysis/NoInferenceModelRunner.h"
3965b40f27SMatt Arsenault #include "llvm/Analysis/Utils/TrainingLogger.h"
400606c64dSNikita Popov #include "llvm/IR/Module.h"
4165b40f27SMatt Arsenault #endif
4265b40f27SMatt Arsenault 
4365b40f27SMatt Arsenault using namespace llvm;
4465b40f27SMatt Arsenault 
4565b40f27SMatt Arsenault static cl::opt<std::string> InteractiveChannelBaseName(
4665b40f27SMatt Arsenault     "regalloc-priority-interactive-channel-base", cl::Hidden,
4765b40f27SMatt Arsenault     cl::desc(
4865b40f27SMatt Arsenault         "Base file path for the interactive mode. The incoming filename should "
4965b40f27SMatt Arsenault         "have the name <regalloc-priority-interactive-channel-base>.in, while "
5065b40f27SMatt Arsenault         "the outgoing name should be "
5165b40f27SMatt Arsenault         "<regalloc-priority-interactive-channel-base>.out"));
5265b40f27SMatt Arsenault 
5365b40f27SMatt Arsenault using CompiledModelType = NoopSavedModelImpl;
5465b40f27SMatt Arsenault 
5565b40f27SMatt Arsenault // Options that only make sense in development mode
5665b40f27SMatt Arsenault #ifdef LLVM_HAVE_TFLITE
5765b40f27SMatt Arsenault #include "RegAllocScore.h"
5865b40f27SMatt Arsenault #include "llvm/Analysis/Utils/TFUtils.h"
5965b40f27SMatt Arsenault 
6065b40f27SMatt Arsenault static cl::opt<std::string> TrainingLog(
6165b40f27SMatt Arsenault     "regalloc-priority-training-log", cl::Hidden,
6265b40f27SMatt Arsenault     cl::desc("Training log for the register allocator priority model"));
6365b40f27SMatt Arsenault 
6465b40f27SMatt Arsenault static cl::opt<std::string> ModelUnderTraining(
6565b40f27SMatt Arsenault     "regalloc-priority-model", cl::Hidden,
6665b40f27SMatt Arsenault     cl::desc("The model being trained for register allocation priority"));
6765b40f27SMatt Arsenault 
6865b40f27SMatt Arsenault #endif // #ifdef LLVM_HAVE_TFLITE
6965b40f27SMatt Arsenault 
7065b40f27SMatt Arsenault namespace llvm {
7165b40f27SMatt Arsenault 
7265b40f27SMatt Arsenault static const std::vector<int64_t> PerLiveRangeShape{1};
7365b40f27SMatt Arsenault 
7465b40f27SMatt Arsenault #define RA_PRIORITY_FEATURES_LIST(M)                                           \
7565b40f27SMatt Arsenault   M(int64_t, li_size, PerLiveRangeShape, "size")                               \
7665b40f27SMatt Arsenault   M(int64_t, stage, PerLiveRangeShape, "stage")                                \
7765b40f27SMatt Arsenault   M(float, weight, PerLiveRangeShape, "weight")
7865b40f27SMatt Arsenault 
7965b40f27SMatt Arsenault #define DecisionName "priority"
8065b40f27SMatt Arsenault static const TensorSpec DecisionSpec =
8165b40f27SMatt Arsenault     TensorSpec::createSpec<float>(DecisionName, {1});
8265b40f27SMatt Arsenault 
8365b40f27SMatt Arsenault 
8465b40f27SMatt Arsenault // Named features index.
8565b40f27SMatt Arsenault enum FeatureIDs {
8665b40f27SMatt Arsenault #define _FEATURE_IDX(_, name, __, ___) name,
8765b40f27SMatt Arsenault   RA_PRIORITY_FEATURES_LIST(_FEATURE_IDX)
8865b40f27SMatt Arsenault #undef _FEATURE_IDX
8965b40f27SMatt Arsenault       FeatureCount
9065b40f27SMatt Arsenault };
9165b40f27SMatt Arsenault 
9265b40f27SMatt Arsenault class MLPriorityAdvisor : public RegAllocPriorityAdvisor {
9365b40f27SMatt Arsenault public:
9465b40f27SMatt Arsenault   MLPriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA,
9565b40f27SMatt Arsenault                     SlotIndexes *const Indexes, MLModelRunner *Runner);
9665b40f27SMatt Arsenault 
9765b40f27SMatt Arsenault protected:
9865b40f27SMatt Arsenault   const RegAllocPriorityAdvisor &getDefaultAdvisor() const {
9965b40f27SMatt Arsenault     return static_cast<const RegAllocPriorityAdvisor &>(DefaultAdvisor);
10065b40f27SMatt Arsenault   }
10165b40f27SMatt Arsenault 
10265b40f27SMatt Arsenault   // The assumption is that if the Runner could not be constructed, we emit-ed
10365b40f27SMatt Arsenault   // error, and we shouldn't be asking for it here.
10465b40f27SMatt Arsenault   const MLModelRunner &getRunner() const { return *Runner; }
10565b40f27SMatt Arsenault   float getPriorityImpl(const LiveInterval &LI) const;
10665b40f27SMatt Arsenault   unsigned getPriority(const LiveInterval &LI) const override;
10765b40f27SMatt Arsenault 
10865b40f27SMatt Arsenault private:
10965b40f27SMatt Arsenault   const DefaultPriorityAdvisor DefaultAdvisor;
11065b40f27SMatt Arsenault   MLModelRunner *const Runner;
11165b40f27SMatt Arsenault };
11265b40f27SMatt Arsenault 
11365b40f27SMatt Arsenault #define _DECL_FEATURES(type, name, shape, _)                                   \
11465b40f27SMatt Arsenault   TensorSpec::createSpec<type>(#name, shape),
11565b40f27SMatt Arsenault 
11665b40f27SMatt Arsenault static const std::vector<TensorSpec> InputFeatures{
11765b40f27SMatt Arsenault     {RA_PRIORITY_FEATURES_LIST(_DECL_FEATURES)},
11865b40f27SMatt Arsenault };
11965b40f27SMatt Arsenault #undef _DECL_FEATURES
12065b40f27SMatt Arsenault 
12165b40f27SMatt Arsenault // ===================================
12265b40f27SMatt Arsenault // Release (AOT) - specifics
12365b40f27SMatt Arsenault // ===================================
12465b40f27SMatt Arsenault class ReleaseModePriorityAdvisorAnalysis final
12565b40f27SMatt Arsenault     : public RegAllocPriorityAdvisorAnalysis {
12665b40f27SMatt Arsenault public:
12765b40f27SMatt Arsenault   ReleaseModePriorityAdvisorAnalysis()
12865b40f27SMatt Arsenault       : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Release) {}
12965b40f27SMatt Arsenault   // support for isa<> and dyn_cast.
13065b40f27SMatt Arsenault   static bool classof(const RegAllocPriorityAdvisorAnalysis *R) {
13165b40f27SMatt Arsenault     return R->getAdvisorMode() == AdvisorMode::Release;
13265b40f27SMatt Arsenault   }
13365b40f27SMatt Arsenault 
13465b40f27SMatt Arsenault private:
13565b40f27SMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
13665b40f27SMatt Arsenault     AU.setPreservesAll();
137*4010f894Spaperchalice     AU.addRequired<SlotIndexesWrapperPass>();
13865b40f27SMatt Arsenault     RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU);
13965b40f27SMatt Arsenault   }
14065b40f27SMatt Arsenault 
14165b40f27SMatt Arsenault   std::unique_ptr<RegAllocPriorityAdvisor>
14265b40f27SMatt Arsenault   getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
14365b40f27SMatt Arsenault     if (!Runner) {
14465b40f27SMatt Arsenault       if (InteractiveChannelBaseName.empty())
14565b40f27SMatt Arsenault         Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
14665b40f27SMatt Arsenault             MF.getFunction().getContext(), InputFeatures, DecisionName);
14765b40f27SMatt Arsenault       else
14865b40f27SMatt Arsenault         Runner = std::make_unique<InteractiveModelRunner>(
14965b40f27SMatt Arsenault             MF.getFunction().getContext(), InputFeatures, DecisionSpec,
15065b40f27SMatt Arsenault             InteractiveChannelBaseName + ".out",
15165b40f27SMatt Arsenault             InteractiveChannelBaseName + ".in");
15265b40f27SMatt Arsenault     }
15365b40f27SMatt Arsenault     return std::make_unique<MLPriorityAdvisor>(
154*4010f894Spaperchalice         MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get());
15565b40f27SMatt Arsenault   }
15665b40f27SMatt Arsenault   std::unique_ptr<MLModelRunner> Runner;
15765b40f27SMatt Arsenault };
15865b40f27SMatt Arsenault 
15965b40f27SMatt Arsenault // ===================================
16065b40f27SMatt Arsenault // Development mode-specifics
16165b40f27SMatt Arsenault // ===================================
16265b40f27SMatt Arsenault //
16365b40f27SMatt Arsenault // Features we log
16465b40f27SMatt Arsenault #ifdef LLVM_HAVE_TFLITE
16565b40f27SMatt Arsenault static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1});
16665b40f27SMatt Arsenault 
16765b40f27SMatt Arsenault #define _DECL_TRAIN_FEATURES(type, name, shape, _)                             \
16865b40f27SMatt Arsenault   TensorSpec::createSpec<type>(std::string("action_") + #name, shape),
16965b40f27SMatt Arsenault 
17065b40f27SMatt Arsenault static const std::vector<TensorSpec> TrainingInputFeatures{
17165b40f27SMatt Arsenault     {RA_PRIORITY_FEATURES_LIST(_DECL_TRAIN_FEATURES)
17265b40f27SMatt Arsenault          TensorSpec::createSpec<float>("action_discount", {1}),
17365b40f27SMatt Arsenault      TensorSpec::createSpec<int32_t>("action_step_type", {1}),
17465b40f27SMatt Arsenault      TensorSpec::createSpec<float>("action_reward", {1})}};
17565b40f27SMatt Arsenault #undef _DECL_TRAIN_FEATURES
17665b40f27SMatt Arsenault 
17765b40f27SMatt Arsenault class DevelopmentModePriorityAdvisor : public MLPriorityAdvisor {
17865b40f27SMatt Arsenault public:
17965b40f27SMatt Arsenault   DevelopmentModePriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA,
18065b40f27SMatt Arsenault                                  SlotIndexes *const Indexes,
18165b40f27SMatt Arsenault                                  MLModelRunner *Runner, Logger *Log)
18265b40f27SMatt Arsenault       : MLPriorityAdvisor(MF, RA, Indexes, Runner), Log(Log) {}
18365b40f27SMatt Arsenault 
18465b40f27SMatt Arsenault private:
18565b40f27SMatt Arsenault   unsigned getPriority(const LiveInterval &LI) const override;
18665b40f27SMatt Arsenault   Logger *const Log;
18765b40f27SMatt Arsenault };
18865b40f27SMatt Arsenault 
18965b40f27SMatt Arsenault class DevelopmentModePriorityAdvisorAnalysis final
19065b40f27SMatt Arsenault     : public RegAllocPriorityAdvisorAnalysis {
19165b40f27SMatt Arsenault public:
19265b40f27SMatt Arsenault   DevelopmentModePriorityAdvisorAnalysis()
19365b40f27SMatt Arsenault       : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Development) {}
19465b40f27SMatt Arsenault   // support for isa<> and dyn_cast.
19565b40f27SMatt Arsenault   static bool classof(const RegAllocPriorityAdvisorAnalysis *R) {
19665b40f27SMatt Arsenault     return R->getAdvisorMode() == AdvisorMode::Development;
19765b40f27SMatt Arsenault   }
19865b40f27SMatt Arsenault 
19965b40f27SMatt Arsenault   void logRewardIfNeeded(const MachineFunction &MF,
20065b40f27SMatt Arsenault                          llvm::function_ref<float()> GetReward) override {
20165b40f27SMatt Arsenault     if (!Log || !Log->hasAnyObservationForContext(MF.getName()))
20265b40f27SMatt Arsenault       return;
20365b40f27SMatt Arsenault     // The function pass manager would run all the function passes for a
20465b40f27SMatt Arsenault     // function, so we assume the last context belongs to this function. If
20565b40f27SMatt Arsenault     // this invariant ever changes, we can implement at that time switching
20665b40f27SMatt Arsenault     // contexts. At this point, it'd be an error
20765b40f27SMatt Arsenault     if (Log->currentContext() != MF.getName()) {
20865b40f27SMatt Arsenault       MF.getFunction().getContext().emitError(
20965b40f27SMatt Arsenault           "The training log context shouldn't have had changed.");
21065b40f27SMatt Arsenault     }
21165b40f27SMatt Arsenault     if (Log->hasObservationInProgress())
21265b40f27SMatt Arsenault       Log->logReward<float>(GetReward());
21365b40f27SMatt Arsenault   }
21465b40f27SMatt Arsenault 
21565b40f27SMatt Arsenault private:
21665b40f27SMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
21765b40f27SMatt Arsenault     AU.setPreservesAll();
218*4010f894Spaperchalice     AU.addRequired<SlotIndexesWrapperPass>();
21965b40f27SMatt Arsenault     RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU);
22065b40f27SMatt Arsenault   }
22165b40f27SMatt Arsenault 
22265b40f27SMatt Arsenault   // Save all the logs (when requested).
22365b40f27SMatt Arsenault   bool doInitialization(Module &M) override {
22465b40f27SMatt Arsenault     LLVMContext &Ctx = M.getContext();
22565b40f27SMatt Arsenault     if (ModelUnderTraining.empty() && TrainingLog.empty()) {
22665b40f27SMatt Arsenault       Ctx.emitError("Regalloc development mode should be requested with at "
22765b40f27SMatt Arsenault                     "least logging enabled and/or a training model");
22865b40f27SMatt Arsenault       return false;
22965b40f27SMatt Arsenault     }
23065b40f27SMatt Arsenault     if (ModelUnderTraining.empty())
23165b40f27SMatt Arsenault       Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures);
23265b40f27SMatt Arsenault     else
23365b40f27SMatt Arsenault       Runner = ModelUnderTrainingRunner::createAndEnsureValid(
23465b40f27SMatt Arsenault           Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures);
23565b40f27SMatt Arsenault     if (!Runner) {
23665b40f27SMatt Arsenault       Ctx.emitError("Regalloc: could not set up the model runner");
23765b40f27SMatt Arsenault       return false;
23865b40f27SMatt Arsenault     }
23965b40f27SMatt Arsenault     if (TrainingLog.empty())
24065b40f27SMatt Arsenault       return false;
24165b40f27SMatt Arsenault     std::error_code EC;
24265b40f27SMatt Arsenault     auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC);
24365b40f27SMatt Arsenault     if (EC) {
24465b40f27SMatt Arsenault       M.getContext().emitError(EC.message() + ":" + TrainingLog);
24565b40f27SMatt Arsenault       return false;
24665b40f27SMatt Arsenault     }
24765b40f27SMatt Arsenault     std::vector<TensorSpec> LFS = InputFeatures;
24865b40f27SMatt Arsenault     if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get()))
24965b40f27SMatt Arsenault       append_range(LFS, MUTR->extraOutputsForLoggingSpecs());
25065b40f27SMatt Arsenault     // We always log the output; in particular, if we're not evaluating, we
25165b40f27SMatt Arsenault     // don't have an output spec json file. That's why we handle the
25265b40f27SMatt Arsenault     // 'normal' output separately.
25365b40f27SMatt Arsenault     LFS.push_back(DecisionSpec);
25465b40f27SMatt Arsenault 
25565b40f27SMatt Arsenault     Log = std::make_unique<Logger>(std::move(OS), LFS, Reward,
25665b40f27SMatt Arsenault                                    /*IncludeReward*/ true);
25765b40f27SMatt Arsenault     return false;
25865b40f27SMatt Arsenault   }
25965b40f27SMatt Arsenault 
26065b40f27SMatt Arsenault   std::unique_ptr<RegAllocPriorityAdvisor>
26165b40f27SMatt Arsenault   getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
26265b40f27SMatt Arsenault     if (!Runner)
26365b40f27SMatt Arsenault       return nullptr;
26465b40f27SMatt Arsenault     if (Log) {
26565b40f27SMatt Arsenault       Log->switchContext(MF.getName());
26665b40f27SMatt Arsenault     }
26765b40f27SMatt Arsenault 
26865b40f27SMatt Arsenault     return std::make_unique<DevelopmentModePriorityAdvisor>(
269*4010f894Spaperchalice         MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get(),
270*4010f894Spaperchalice         Log.get());
27165b40f27SMatt Arsenault   }
27265b40f27SMatt Arsenault 
27365b40f27SMatt Arsenault   std::unique_ptr<MLModelRunner> Runner;
27465b40f27SMatt Arsenault   std::unique_ptr<Logger> Log;
27565b40f27SMatt Arsenault };
27665b40f27SMatt Arsenault #endif //#ifdef LLVM_HAVE_TFLITE
27765b40f27SMatt Arsenault 
27865b40f27SMatt Arsenault } // namespace llvm
27965b40f27SMatt Arsenault 
28065b40f27SMatt Arsenault RegAllocPriorityAdvisorAnalysis *llvm::createReleaseModePriorityAdvisor() {
28165b40f27SMatt Arsenault   return llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() ||
28265b40f27SMatt Arsenault                  !InteractiveChannelBaseName.empty()
28365b40f27SMatt Arsenault              ? new ReleaseModePriorityAdvisorAnalysis()
28465b40f27SMatt Arsenault              : nullptr;
28565b40f27SMatt Arsenault }
28665b40f27SMatt Arsenault 
28765b40f27SMatt Arsenault MLPriorityAdvisor::MLPriorityAdvisor(const MachineFunction &MF,
28865b40f27SMatt Arsenault                                      const RAGreedy &RA,
28965b40f27SMatt Arsenault                                      SlotIndexes *const Indexes,
29065b40f27SMatt Arsenault                                      MLModelRunner *Runner)
29165b40f27SMatt Arsenault     : RegAllocPriorityAdvisor(MF, RA, Indexes), DefaultAdvisor(MF, RA, Indexes),
29265b40f27SMatt Arsenault       Runner(std::move(Runner)) {
29365b40f27SMatt Arsenault   assert(this->Runner);
29465b40f27SMatt Arsenault   Runner->switchContext(MF.getName());
29565b40f27SMatt Arsenault }
29665b40f27SMatt Arsenault 
29765b40f27SMatt Arsenault float MLPriorityAdvisor::getPriorityImpl(const LiveInterval &LI) const {
29865b40f27SMatt Arsenault   const unsigned Size = LI.getSize();
29965b40f27SMatt Arsenault   LiveRangeStage Stage = RA.getExtraInfo().getStage(LI);
30065b40f27SMatt Arsenault 
30165b40f27SMatt Arsenault   *Runner->getTensor<int64_t>(0) = static_cast<int64_t>(Size);
30265b40f27SMatt Arsenault   *Runner->getTensor<int64_t>(1) = static_cast<int64_t>(Stage);
30365b40f27SMatt Arsenault   *Runner->getTensor<float>(2) = static_cast<float>(LI.weight());
30465b40f27SMatt Arsenault 
30565b40f27SMatt Arsenault   return Runner->evaluate<float>();
30665b40f27SMatt Arsenault }
30765b40f27SMatt Arsenault 
30865b40f27SMatt Arsenault unsigned MLPriorityAdvisor::getPriority(const LiveInterval &LI) const {
30965b40f27SMatt Arsenault   return static_cast<unsigned>(getPriorityImpl(LI));
31065b40f27SMatt Arsenault }
31165b40f27SMatt Arsenault 
31265b40f27SMatt Arsenault #ifdef LLVM_HAVE_TFLITE
31365b40f27SMatt Arsenault RegAllocPriorityAdvisorAnalysis *llvm::createDevelopmentModePriorityAdvisor() {
31465b40f27SMatt Arsenault   return new DevelopmentModePriorityAdvisorAnalysis();
31565b40f27SMatt Arsenault }
31665b40f27SMatt Arsenault 
31765b40f27SMatt Arsenault unsigned
31865b40f27SMatt Arsenault DevelopmentModePriorityAdvisor::getPriority(const LiveInterval &LI) const {
31965b40f27SMatt Arsenault   double Prio = 0;
32065b40f27SMatt Arsenault 
32165b40f27SMatt Arsenault   if (isa<ModelUnderTrainingRunner>(getRunner())) {
32265b40f27SMatt Arsenault     Prio = MLPriorityAdvisor::getPriorityImpl(LI);
32365b40f27SMatt Arsenault   } else {
32465b40f27SMatt Arsenault     Prio = getDefaultAdvisor().getPriority(LI);
32565b40f27SMatt Arsenault   }
32665b40f27SMatt Arsenault 
32765b40f27SMatt Arsenault   if (TrainingLog.empty())
32865b40f27SMatt Arsenault     return Prio;
32965b40f27SMatt Arsenault 
33065b40f27SMatt Arsenault   // TODO(mtrofin): when we support optional rewards, this can go away. In the
33165b40f27SMatt Arsenault   // meantime, we log the "pretend" reward (0) for the previous observation
33265b40f27SMatt Arsenault   // before starting a new one.
33365b40f27SMatt Arsenault   if (Log->hasObservationInProgress())
33465b40f27SMatt Arsenault     Log->logReward<float>(0.0);
33565b40f27SMatt Arsenault 
33665b40f27SMatt Arsenault   Log->startObservation();
33765b40f27SMatt Arsenault   size_t CurrentFeature = 0;
33865b40f27SMatt Arsenault   for (; CurrentFeature < InputFeatures.size(); ++CurrentFeature) {
33965b40f27SMatt Arsenault     Log->logTensorValue(CurrentFeature,
34065b40f27SMatt Arsenault                         reinterpret_cast<const char *>(
34165b40f27SMatt Arsenault                             getRunner().getTensorUntyped(CurrentFeature)));
34265b40f27SMatt Arsenault   }
34365b40f27SMatt Arsenault 
34465b40f27SMatt Arsenault   if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner())) {
34565b40f27SMatt Arsenault     for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size();
34665b40f27SMatt Arsenault          ++I, ++CurrentFeature)
34765b40f27SMatt Arsenault       Log->logTensorValue(
34865b40f27SMatt Arsenault           CurrentFeature,
34965b40f27SMatt Arsenault           reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I)));
35065b40f27SMatt Arsenault   }
35165b40f27SMatt Arsenault 
35265b40f27SMatt Arsenault   float Ret = static_cast<float>(Prio);
35365b40f27SMatt Arsenault   Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret));
35465b40f27SMatt Arsenault   Log->endObservation();
35565b40f27SMatt Arsenault 
35665b40f27SMatt Arsenault   return static_cast<unsigned>(Prio);
35765b40f27SMatt Arsenault }
35865b40f27SMatt Arsenault 
35965b40f27SMatt Arsenault #endif // #ifdef LLVM_HAVE_TFLITE
360