1*5f757f3fSDimitry Andric //===- MLRegAllocPriorityAdvisor.cpp - ML priority advisor-----------------===// 2*5f757f3fSDimitry Andric // 3*5f757f3fSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*5f757f3fSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*5f757f3fSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*5f757f3fSDimitry Andric // 7*5f757f3fSDimitry Andric //===----------------------------------------------------------------------===// 8*5f757f3fSDimitry Andric // 9*5f757f3fSDimitry Andric // Implementation of the ML priority advisor and reward injection pass 10*5f757f3fSDimitry Andric // 11*5f757f3fSDimitry Andric //===----------------------------------------------------------------------===// 12*5f757f3fSDimitry Andric 13*5f757f3fSDimitry Andric #include "AllocationOrder.h" 14*5f757f3fSDimitry Andric #include "RegAllocGreedy.h" 15*5f757f3fSDimitry Andric #include "RegAllocPriorityAdvisor.h" 16*5f757f3fSDimitry Andric #include "llvm/Analysis/AliasAnalysis.h" 17*5f757f3fSDimitry Andric #include "llvm/Analysis/InteractiveModelRunner.h" 18*5f757f3fSDimitry Andric #include "llvm/Analysis/MLModelRunner.h" 19*5f757f3fSDimitry Andric #include "llvm/Analysis/ReleaseModeModelRunner.h" 20*5f757f3fSDimitry Andric #include "llvm/Analysis/TensorSpec.h" 21*5f757f3fSDimitry Andric #include "llvm/CodeGen/CalcSpillWeights.h" 22*5f757f3fSDimitry Andric #include "llvm/CodeGen/LiveRegMatrix.h" 23*5f757f3fSDimitry Andric #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" 24*5f757f3fSDimitry Andric #include "llvm/CodeGen/MachineFunction.h" 25*5f757f3fSDimitry Andric #include "llvm/CodeGen/MachineLoopInfo.h" 26*5f757f3fSDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h" 27*5f757f3fSDimitry Andric #include "llvm/CodeGen/Passes.h" 28*5f757f3fSDimitry Andric #include "llvm/CodeGen/RegisterClassInfo.h" 29*5f757f3fSDimitry Andric #include "llvm/CodeGen/SlotIndexes.h" 30*5f757f3fSDimitry Andric #include "llvm/CodeGen/VirtRegMap.h" 31*5f757f3fSDimitry Andric #include "llvm/InitializePasses.h" 32*5f757f3fSDimitry Andric #include "llvm/Pass.h" 33*5f757f3fSDimitry Andric #include "llvm/PassRegistry.h" 34*5f757f3fSDimitry Andric #include "llvm/Support/CommandLine.h" 35*5f757f3fSDimitry Andric 36*5f757f3fSDimitry Andric #if defined(LLVM_HAVE_TFLITE) 37*5f757f3fSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h" 38*5f757f3fSDimitry Andric #include "llvm/Analysis/NoInferenceModelRunner.h" 39*5f757f3fSDimitry Andric #include "llvm/Analysis/Utils/TrainingLogger.h" 40*5f757f3fSDimitry Andric #endif 41*5f757f3fSDimitry Andric 42*5f757f3fSDimitry Andric using namespace llvm; 43*5f757f3fSDimitry Andric 44*5f757f3fSDimitry Andric static cl::opt<std::string> InteractiveChannelBaseName( 45*5f757f3fSDimitry Andric "regalloc-priority-interactive-channel-base", cl::Hidden, 46*5f757f3fSDimitry Andric cl::desc( 47*5f757f3fSDimitry Andric "Base file path for the interactive mode. The incoming filename should " 48*5f757f3fSDimitry Andric "have the name <regalloc-priority-interactive-channel-base>.in, while " 49*5f757f3fSDimitry Andric "the outgoing name should be " 50*5f757f3fSDimitry Andric "<regalloc-priority-interactive-channel-base>.out")); 51*5f757f3fSDimitry Andric 52*5f757f3fSDimitry Andric using CompiledModelType = NoopSavedModelImpl; 53*5f757f3fSDimitry Andric 54*5f757f3fSDimitry Andric // Options that only make sense in development mode 55*5f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE 56*5f757f3fSDimitry Andric #include "RegAllocScore.h" 57*5f757f3fSDimitry Andric #include "llvm/Analysis/Utils/TFUtils.h" 58*5f757f3fSDimitry Andric 59*5f757f3fSDimitry Andric static cl::opt<std::string> TrainingLog( 60*5f757f3fSDimitry Andric "regalloc-priority-training-log", cl::Hidden, 61*5f757f3fSDimitry Andric cl::desc("Training log for the register allocator priority model")); 62*5f757f3fSDimitry Andric 63*5f757f3fSDimitry Andric static cl::opt<std::string> ModelUnderTraining( 64*5f757f3fSDimitry Andric "regalloc-priority-model", cl::Hidden, 65*5f757f3fSDimitry Andric cl::desc("The model being trained for register allocation priority")); 66*5f757f3fSDimitry Andric 67*5f757f3fSDimitry Andric #endif // #ifdef LLVM_HAVE_TFLITE 68*5f757f3fSDimitry Andric 69*5f757f3fSDimitry Andric namespace llvm { 70*5f757f3fSDimitry Andric 71*5f757f3fSDimitry Andric static const std::vector<int64_t> PerLiveRangeShape{1}; 72*5f757f3fSDimitry Andric 73*5f757f3fSDimitry Andric #define RA_PRIORITY_FEATURES_LIST(M) \ 74*5f757f3fSDimitry Andric M(int64_t, li_size, PerLiveRangeShape, "size") \ 75*5f757f3fSDimitry Andric M(int64_t, stage, PerLiveRangeShape, "stage") \ 76*5f757f3fSDimitry Andric M(float, weight, PerLiveRangeShape, "weight") 77*5f757f3fSDimitry Andric 78*5f757f3fSDimitry Andric #define DecisionName "priority" 79*5f757f3fSDimitry Andric static const TensorSpec DecisionSpec = 80*5f757f3fSDimitry Andric TensorSpec::createSpec<float>(DecisionName, {1}); 81*5f757f3fSDimitry Andric 82*5f757f3fSDimitry Andric 83*5f757f3fSDimitry Andric // Named features index. 84*5f757f3fSDimitry Andric enum FeatureIDs { 85*5f757f3fSDimitry Andric #define _FEATURE_IDX(_, name, __, ___) name, 86*5f757f3fSDimitry Andric RA_PRIORITY_FEATURES_LIST(_FEATURE_IDX) 87*5f757f3fSDimitry Andric #undef _FEATURE_IDX 88*5f757f3fSDimitry Andric FeatureCount 89*5f757f3fSDimitry Andric }; 90*5f757f3fSDimitry Andric 91*5f757f3fSDimitry Andric class MLPriorityAdvisor : public RegAllocPriorityAdvisor { 92*5f757f3fSDimitry Andric public: 93*5f757f3fSDimitry Andric MLPriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA, 94*5f757f3fSDimitry Andric SlotIndexes *const Indexes, MLModelRunner *Runner); 95*5f757f3fSDimitry Andric 96*5f757f3fSDimitry Andric protected: 97*5f757f3fSDimitry Andric const RegAllocPriorityAdvisor &getDefaultAdvisor() const { 98*5f757f3fSDimitry Andric return static_cast<const RegAllocPriorityAdvisor &>(DefaultAdvisor); 99*5f757f3fSDimitry Andric } 100*5f757f3fSDimitry Andric 101*5f757f3fSDimitry Andric // The assumption is that if the Runner could not be constructed, we emit-ed 102*5f757f3fSDimitry Andric // error, and we shouldn't be asking for it here. 103*5f757f3fSDimitry Andric const MLModelRunner &getRunner() const { return *Runner; } 104*5f757f3fSDimitry Andric float getPriorityImpl(const LiveInterval &LI) const; 105*5f757f3fSDimitry Andric unsigned getPriority(const LiveInterval &LI) const override; 106*5f757f3fSDimitry Andric 107*5f757f3fSDimitry Andric private: 108*5f757f3fSDimitry Andric const DefaultPriorityAdvisor DefaultAdvisor; 109*5f757f3fSDimitry Andric MLModelRunner *const Runner; 110*5f757f3fSDimitry Andric }; 111*5f757f3fSDimitry Andric 112*5f757f3fSDimitry Andric #define _DECL_FEATURES(type, name, shape, _) \ 113*5f757f3fSDimitry Andric TensorSpec::createSpec<type>(#name, shape), 114*5f757f3fSDimitry Andric 115*5f757f3fSDimitry Andric static const std::vector<TensorSpec> InputFeatures{ 116*5f757f3fSDimitry Andric {RA_PRIORITY_FEATURES_LIST(_DECL_FEATURES)}, 117*5f757f3fSDimitry Andric }; 118*5f757f3fSDimitry Andric #undef _DECL_FEATURES 119*5f757f3fSDimitry Andric 120*5f757f3fSDimitry Andric // =================================== 121*5f757f3fSDimitry Andric // Release (AOT) - specifics 122*5f757f3fSDimitry Andric // =================================== 123*5f757f3fSDimitry Andric class ReleaseModePriorityAdvisorAnalysis final 124*5f757f3fSDimitry Andric : public RegAllocPriorityAdvisorAnalysis { 125*5f757f3fSDimitry Andric public: 126*5f757f3fSDimitry Andric ReleaseModePriorityAdvisorAnalysis() 127*5f757f3fSDimitry Andric : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Release) {} 128*5f757f3fSDimitry Andric // support for isa<> and dyn_cast. 129*5f757f3fSDimitry Andric static bool classof(const RegAllocPriorityAdvisorAnalysis *R) { 130*5f757f3fSDimitry Andric return R->getAdvisorMode() == AdvisorMode::Release; 131*5f757f3fSDimitry Andric } 132*5f757f3fSDimitry Andric 133*5f757f3fSDimitry Andric private: 134*5f757f3fSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 135*5f757f3fSDimitry Andric AU.setPreservesAll(); 136*5f757f3fSDimitry Andric AU.addRequired<SlotIndexes>(); 137*5f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU); 138*5f757f3fSDimitry Andric } 139*5f757f3fSDimitry Andric 140*5f757f3fSDimitry Andric std::unique_ptr<RegAllocPriorityAdvisor> 141*5f757f3fSDimitry Andric getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override { 142*5f757f3fSDimitry Andric if (!Runner) { 143*5f757f3fSDimitry Andric if (InteractiveChannelBaseName.empty()) 144*5f757f3fSDimitry Andric Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>( 145*5f757f3fSDimitry Andric MF.getFunction().getContext(), InputFeatures, DecisionName); 146*5f757f3fSDimitry Andric else 147*5f757f3fSDimitry Andric Runner = std::make_unique<InteractiveModelRunner>( 148*5f757f3fSDimitry Andric MF.getFunction().getContext(), InputFeatures, DecisionSpec, 149*5f757f3fSDimitry Andric InteractiveChannelBaseName + ".out", 150*5f757f3fSDimitry Andric InteractiveChannelBaseName + ".in"); 151*5f757f3fSDimitry Andric } 152*5f757f3fSDimitry Andric return std::make_unique<MLPriorityAdvisor>( 153*5f757f3fSDimitry Andric MF, RA, &getAnalysis<SlotIndexes>(), Runner.get()); 154*5f757f3fSDimitry Andric } 155*5f757f3fSDimitry Andric std::unique_ptr<MLModelRunner> Runner; 156*5f757f3fSDimitry Andric }; 157*5f757f3fSDimitry Andric 158*5f757f3fSDimitry Andric // =================================== 159*5f757f3fSDimitry Andric // Development mode-specifics 160*5f757f3fSDimitry Andric // =================================== 161*5f757f3fSDimitry Andric // 162*5f757f3fSDimitry Andric // Features we log 163*5f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE 164*5f757f3fSDimitry Andric static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1}); 165*5f757f3fSDimitry Andric 166*5f757f3fSDimitry Andric #define _DECL_TRAIN_FEATURES(type, name, shape, _) \ 167*5f757f3fSDimitry Andric TensorSpec::createSpec<type>(std::string("action_") + #name, shape), 168*5f757f3fSDimitry Andric 169*5f757f3fSDimitry Andric static const std::vector<TensorSpec> TrainingInputFeatures{ 170*5f757f3fSDimitry Andric {RA_PRIORITY_FEATURES_LIST(_DECL_TRAIN_FEATURES) 171*5f757f3fSDimitry Andric TensorSpec::createSpec<float>("action_discount", {1}), 172*5f757f3fSDimitry Andric TensorSpec::createSpec<int32_t>("action_step_type", {1}), 173*5f757f3fSDimitry Andric TensorSpec::createSpec<float>("action_reward", {1})}}; 174*5f757f3fSDimitry Andric #undef _DECL_TRAIN_FEATURES 175*5f757f3fSDimitry Andric 176*5f757f3fSDimitry Andric class DevelopmentModePriorityAdvisor : public MLPriorityAdvisor { 177*5f757f3fSDimitry Andric public: 178*5f757f3fSDimitry Andric DevelopmentModePriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA, 179*5f757f3fSDimitry Andric SlotIndexes *const Indexes, 180*5f757f3fSDimitry Andric MLModelRunner *Runner, Logger *Log) 181*5f757f3fSDimitry Andric : MLPriorityAdvisor(MF, RA, Indexes, Runner), Log(Log) {} 182*5f757f3fSDimitry Andric 183*5f757f3fSDimitry Andric private: 184*5f757f3fSDimitry Andric unsigned getPriority(const LiveInterval &LI) const override; 185*5f757f3fSDimitry Andric Logger *const Log; 186*5f757f3fSDimitry Andric }; 187*5f757f3fSDimitry Andric 188*5f757f3fSDimitry Andric class DevelopmentModePriorityAdvisorAnalysis final 189*5f757f3fSDimitry Andric : public RegAllocPriorityAdvisorAnalysis { 190*5f757f3fSDimitry Andric public: 191*5f757f3fSDimitry Andric DevelopmentModePriorityAdvisorAnalysis() 192*5f757f3fSDimitry Andric : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Development) {} 193*5f757f3fSDimitry Andric // support for isa<> and dyn_cast. 194*5f757f3fSDimitry Andric static bool classof(const RegAllocPriorityAdvisorAnalysis *R) { 195*5f757f3fSDimitry Andric return R->getAdvisorMode() == AdvisorMode::Development; 196*5f757f3fSDimitry Andric } 197*5f757f3fSDimitry Andric 198*5f757f3fSDimitry Andric void logRewardIfNeeded(const MachineFunction &MF, 199*5f757f3fSDimitry Andric llvm::function_ref<float()> GetReward) override { 200*5f757f3fSDimitry Andric if (!Log || !Log->hasAnyObservationForContext(MF.getName())) 201*5f757f3fSDimitry Andric return; 202*5f757f3fSDimitry Andric // The function pass manager would run all the function passes for a 203*5f757f3fSDimitry Andric // function, so we assume the last context belongs to this function. If 204*5f757f3fSDimitry Andric // this invariant ever changes, we can implement at that time switching 205*5f757f3fSDimitry Andric // contexts. At this point, it'd be an error 206*5f757f3fSDimitry Andric if (Log->currentContext() != MF.getName()) { 207*5f757f3fSDimitry Andric MF.getFunction().getContext().emitError( 208*5f757f3fSDimitry Andric "The training log context shouldn't have had changed."); 209*5f757f3fSDimitry Andric } 210*5f757f3fSDimitry Andric if (Log->hasObservationInProgress()) 211*5f757f3fSDimitry Andric Log->logReward<float>(GetReward()); 212*5f757f3fSDimitry Andric } 213*5f757f3fSDimitry Andric 214*5f757f3fSDimitry Andric private: 215*5f757f3fSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 216*5f757f3fSDimitry Andric AU.setPreservesAll(); 217*5f757f3fSDimitry Andric AU.addRequired<SlotIndexes>(); 218*5f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU); 219*5f757f3fSDimitry Andric } 220*5f757f3fSDimitry Andric 221*5f757f3fSDimitry Andric // Save all the logs (when requested). 222*5f757f3fSDimitry Andric bool doInitialization(Module &M) override { 223*5f757f3fSDimitry Andric LLVMContext &Ctx = M.getContext(); 224*5f757f3fSDimitry Andric if (ModelUnderTraining.empty() && TrainingLog.empty()) { 225*5f757f3fSDimitry Andric Ctx.emitError("Regalloc development mode should be requested with at " 226*5f757f3fSDimitry Andric "least logging enabled and/or a training model"); 227*5f757f3fSDimitry Andric return false; 228*5f757f3fSDimitry Andric } 229*5f757f3fSDimitry Andric if (ModelUnderTraining.empty()) 230*5f757f3fSDimitry Andric Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures); 231*5f757f3fSDimitry Andric else 232*5f757f3fSDimitry Andric Runner = ModelUnderTrainingRunner::createAndEnsureValid( 233*5f757f3fSDimitry Andric Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures); 234*5f757f3fSDimitry Andric if (!Runner) { 235*5f757f3fSDimitry Andric Ctx.emitError("Regalloc: could not set up the model runner"); 236*5f757f3fSDimitry Andric return false; 237*5f757f3fSDimitry Andric } 238*5f757f3fSDimitry Andric if (TrainingLog.empty()) 239*5f757f3fSDimitry Andric return false; 240*5f757f3fSDimitry Andric std::error_code EC; 241*5f757f3fSDimitry Andric auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC); 242*5f757f3fSDimitry Andric if (EC) { 243*5f757f3fSDimitry Andric M.getContext().emitError(EC.message() + ":" + TrainingLog); 244*5f757f3fSDimitry Andric return false; 245*5f757f3fSDimitry Andric } 246*5f757f3fSDimitry Andric std::vector<TensorSpec> LFS = InputFeatures; 247*5f757f3fSDimitry Andric if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get())) 248*5f757f3fSDimitry Andric append_range(LFS, MUTR->extraOutputsForLoggingSpecs()); 249*5f757f3fSDimitry Andric // We always log the output; in particular, if we're not evaluating, we 250*5f757f3fSDimitry Andric // don't have an output spec json file. That's why we handle the 251*5f757f3fSDimitry Andric // 'normal' output separately. 252*5f757f3fSDimitry Andric LFS.push_back(DecisionSpec); 253*5f757f3fSDimitry Andric 254*5f757f3fSDimitry Andric Log = std::make_unique<Logger>(std::move(OS), LFS, Reward, 255*5f757f3fSDimitry Andric /*IncludeReward*/ true); 256*5f757f3fSDimitry Andric return false; 257*5f757f3fSDimitry Andric } 258*5f757f3fSDimitry Andric 259*5f757f3fSDimitry Andric std::unique_ptr<RegAllocPriorityAdvisor> 260*5f757f3fSDimitry Andric getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override { 261*5f757f3fSDimitry Andric if (!Runner) 262*5f757f3fSDimitry Andric return nullptr; 263*5f757f3fSDimitry Andric if (Log) { 264*5f757f3fSDimitry Andric Log->switchContext(MF.getName()); 265*5f757f3fSDimitry Andric } 266*5f757f3fSDimitry Andric 267*5f757f3fSDimitry Andric return std::make_unique<DevelopmentModePriorityAdvisor>( 268*5f757f3fSDimitry Andric MF, RA, &getAnalysis<SlotIndexes>(), Runner.get(), Log.get()); 269*5f757f3fSDimitry Andric } 270*5f757f3fSDimitry Andric 271*5f757f3fSDimitry Andric std::unique_ptr<MLModelRunner> Runner; 272*5f757f3fSDimitry Andric std::unique_ptr<Logger> Log; 273*5f757f3fSDimitry Andric }; 274*5f757f3fSDimitry Andric #endif //#ifdef LLVM_HAVE_TFLITE 275*5f757f3fSDimitry Andric 276*5f757f3fSDimitry Andric } // namespace llvm 277*5f757f3fSDimitry Andric 278*5f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis *llvm::createReleaseModePriorityAdvisor() { 279*5f757f3fSDimitry Andric return llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() || 280*5f757f3fSDimitry Andric !InteractiveChannelBaseName.empty() 281*5f757f3fSDimitry Andric ? new ReleaseModePriorityAdvisorAnalysis() 282*5f757f3fSDimitry Andric : nullptr; 283*5f757f3fSDimitry Andric } 284*5f757f3fSDimitry Andric 285*5f757f3fSDimitry Andric MLPriorityAdvisor::MLPriorityAdvisor(const MachineFunction &MF, 286*5f757f3fSDimitry Andric const RAGreedy &RA, 287*5f757f3fSDimitry Andric SlotIndexes *const Indexes, 288*5f757f3fSDimitry Andric MLModelRunner *Runner) 289*5f757f3fSDimitry Andric : RegAllocPriorityAdvisor(MF, RA, Indexes), DefaultAdvisor(MF, RA, Indexes), 290*5f757f3fSDimitry Andric Runner(std::move(Runner)) { 291*5f757f3fSDimitry Andric assert(this->Runner); 292*5f757f3fSDimitry Andric Runner->switchContext(MF.getName()); 293*5f757f3fSDimitry Andric } 294*5f757f3fSDimitry Andric 295*5f757f3fSDimitry Andric float MLPriorityAdvisor::getPriorityImpl(const LiveInterval &LI) const { 296*5f757f3fSDimitry Andric const unsigned Size = LI.getSize(); 297*5f757f3fSDimitry Andric LiveRangeStage Stage = RA.getExtraInfo().getStage(LI); 298*5f757f3fSDimitry Andric 299*5f757f3fSDimitry Andric *Runner->getTensor<int64_t>(0) = static_cast<int64_t>(Size); 300*5f757f3fSDimitry Andric *Runner->getTensor<int64_t>(1) = static_cast<int64_t>(Stage); 301*5f757f3fSDimitry Andric *Runner->getTensor<float>(2) = static_cast<float>(LI.weight()); 302*5f757f3fSDimitry Andric 303*5f757f3fSDimitry Andric return Runner->evaluate<float>(); 304*5f757f3fSDimitry Andric } 305*5f757f3fSDimitry Andric 306*5f757f3fSDimitry Andric unsigned MLPriorityAdvisor::getPriority(const LiveInterval &LI) const { 307*5f757f3fSDimitry Andric return static_cast<unsigned>(getPriorityImpl(LI)); 308*5f757f3fSDimitry Andric } 309*5f757f3fSDimitry Andric 310*5f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE 311*5f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis *llvm::createDevelopmentModePriorityAdvisor() { 312*5f757f3fSDimitry Andric return new DevelopmentModePriorityAdvisorAnalysis(); 313*5f757f3fSDimitry Andric } 314*5f757f3fSDimitry Andric 315*5f757f3fSDimitry Andric unsigned 316*5f757f3fSDimitry Andric DevelopmentModePriorityAdvisor::getPriority(const LiveInterval &LI) const { 317*5f757f3fSDimitry Andric double Prio = 0; 318*5f757f3fSDimitry Andric 319*5f757f3fSDimitry Andric if (isa<ModelUnderTrainingRunner>(getRunner())) { 320*5f757f3fSDimitry Andric Prio = MLPriorityAdvisor::getPriorityImpl(LI); 321*5f757f3fSDimitry Andric } else { 322*5f757f3fSDimitry Andric Prio = getDefaultAdvisor().getPriority(LI); 323*5f757f3fSDimitry Andric } 324*5f757f3fSDimitry Andric 325*5f757f3fSDimitry Andric if (TrainingLog.empty()) 326*5f757f3fSDimitry Andric return Prio; 327*5f757f3fSDimitry Andric 328*5f757f3fSDimitry Andric // TODO(mtrofin): when we support optional rewards, this can go away. In the 329*5f757f3fSDimitry Andric // meantime, we log the "pretend" reward (0) for the previous observation 330*5f757f3fSDimitry Andric // before starting a new one. 331*5f757f3fSDimitry Andric if (Log->hasObservationInProgress()) 332*5f757f3fSDimitry Andric Log->logReward<float>(0.0); 333*5f757f3fSDimitry Andric 334*5f757f3fSDimitry Andric Log->startObservation(); 335*5f757f3fSDimitry Andric size_t CurrentFeature = 0; 336*5f757f3fSDimitry Andric for (; CurrentFeature < InputFeatures.size(); ++CurrentFeature) { 337*5f757f3fSDimitry Andric Log->logTensorValue(CurrentFeature, 338*5f757f3fSDimitry Andric reinterpret_cast<const char *>( 339*5f757f3fSDimitry Andric getRunner().getTensorUntyped(CurrentFeature))); 340*5f757f3fSDimitry Andric } 341*5f757f3fSDimitry Andric 342*5f757f3fSDimitry Andric if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner())) { 343*5f757f3fSDimitry Andric for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size(); 344*5f757f3fSDimitry Andric ++I, ++CurrentFeature) 345*5f757f3fSDimitry Andric Log->logTensorValue( 346*5f757f3fSDimitry Andric CurrentFeature, 347*5f757f3fSDimitry Andric reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I))); 348*5f757f3fSDimitry Andric } 349*5f757f3fSDimitry Andric 350*5f757f3fSDimitry Andric float Ret = static_cast<float>(Prio); 351*5f757f3fSDimitry Andric Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret)); 352*5f757f3fSDimitry Andric Log->endObservation(); 353*5f757f3fSDimitry Andric 354*5f757f3fSDimitry Andric return static_cast<unsigned>(Prio); 355*5f757f3fSDimitry Andric } 356*5f757f3fSDimitry Andric 357*5f757f3fSDimitry Andric #endif // #ifdef LLVM_HAVE_TFLITE 358