//===- MLRegAllocPriorityAdvisor.cpp - ML priority advisor ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of the ML priority advisor and reward injection pass
//
//===----------------------------------------------------------------------===//

#include "AllocationOrder.h"
#include "RegAllocGreedy.h"
#include "RegAllocPriorityAdvisor.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InteractiveModelRunner.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/CodeGen/CalcSpillWeights.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/RegisterClassInfo.h"
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"

#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Analysis/NoInferenceModelRunner.h"
#include "llvm/Analysis/Utils/TrainingLogger.h"
#include "llvm/IR/Module.h"
#endif

using namespace llvm;

static cl::opt<std::string> InteractiveChannelBaseName(
    "regalloc-priority-interactive-channel-base", cl::Hidden,
    cl::desc(
        "Base file path for the interactive mode. The incoming filename should "
        "have the name <regalloc-priority-interactive-channel-base>.in, while "
        "the outgoing name should be "
        "<regalloc-priority-interactive-channel-base>.out"));

using CompiledModelType = NoopSavedModelImpl;

// Options that only make sense in development mode
#ifdef LLVM_HAVE_TFLITE
#include "RegAllocScore.h"
#include "llvm/Analysis/Utils/TFUtils.h"

static cl::opt<std::string> TrainingLog(
    "regalloc-priority-training-log", cl::Hidden,
    cl::desc("Training log for the register allocator priority model"));

static cl::opt<std::string> ModelUnderTraining(
    "regalloc-priority-model", cl::Hidden,
    cl::desc("The model being trained for register allocation priority"));

#endif // #ifdef LLVM_HAVE_TFLITE

namespace llvm {

static const std::vector<int64_t> PerLiveRangeShape{1};

#define RA_PRIORITY_FEATURES_LIST(M)                                           \
  M(int64_t, li_size, PerLiveRangeShape, "size")                               \
  M(int64_t, stage, PerLiveRangeShape, "stage")                                \
  M(float, weight, PerLiveRangeShape, "weight")

#define DecisionName "priority"
static const TensorSpec DecisionSpec =
    TensorSpec::createSpec<float>(DecisionName, {1});

// Named features index.
enum FeatureIDs {
#define _FEATURE_IDX(_, name, __, ___) name,
  RA_PRIORITY_FEATURES_LIST(_FEATURE_IDX)
#undef _FEATURE_IDX
  FeatureCount
};

class MLPriorityAdvisor : public RegAllocPriorityAdvisor {
public:
  MLPriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                    SlotIndexes *const Indexes, MLModelRunner *Runner);

protected:
  const RegAllocPriorityAdvisor &getDefaultAdvisor() const {
    return static_cast<const RegAllocPriorityAdvisor &>(DefaultAdvisor);
  }

  // If the Runner could not be constructed, an error was already emitted, and
  // we should not be asking for it here.
  const MLModelRunner &getRunner() const { return *Runner; }
  float getPriorityImpl(const LiveInterval &LI) const;
  unsigned getPriority(const LiveInterval &LI) const override;

private:
  const DefaultPriorityAdvisor DefaultAdvisor;
  MLModelRunner *const Runner;
};

#define _DECL_FEATURES(type, name, shape, _)                                   \
  TensorSpec::createSpec<type>(#name, shape),

static const std::vector<TensorSpec> InputFeatures{
    {RA_PRIORITY_FEATURES_LIST(_DECL_FEATURES)},
};
#undef _DECL_FEATURES

// ===================================
// Release (AOT) - specifics
// ===================================
class ReleaseModePriorityAdvisorAnalysis final
    : public RegAllocPriorityAdvisorAnalysis {
public:
  ReleaseModePriorityAdvisorAnalysis()
      : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Release) {}
  // support for isa<> and dyn_cast.
  static bool classof(const RegAllocPriorityAdvisorAnalysis *R) {
    return R->getAdvisorMode() == AdvisorMode::Release;
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<SlotIndexesWrapperPass>();
    RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU);
  }

  std::unique_ptr<RegAllocPriorityAdvisor>
  getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
    if (!Runner) {
      if (InteractiveChannelBaseName.empty())
        Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
            MF.getFunction().getContext(), InputFeatures, DecisionName);
      else
        Runner = std::make_unique<InteractiveModelRunner>(
            MF.getFunction().getContext(), InputFeatures, DecisionSpec,
            InteractiveChannelBaseName + ".out",
            InteractiveChannelBaseName + ".in");
    }
    return std::make_unique<MLPriorityAdvisor>(
        MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get());
  }
  std::unique_ptr<MLModelRunner> Runner;
};

// ===================================
// Development mode-specifics
// ===================================
//
// Features we log
#ifdef LLVM_HAVE_TFLITE
static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1});

#define _DECL_TRAIN_FEATURES(type, name, shape, _)                             \
  TensorSpec::createSpec<type>(std::string("action_") + #name, shape),

static const std::vector<TensorSpec> TrainingInputFeatures{
    {RA_PRIORITY_FEATURES_LIST(_DECL_TRAIN_FEATURES)
     TensorSpec::createSpec<float>("action_discount", {1}),
     TensorSpec::createSpec<int32_t>("action_step_type", {1}),
     TensorSpec::createSpec<float>("action_reward", {1})}};
#undef _DECL_TRAIN_FEATURES

class DevelopmentModePriorityAdvisor : public MLPriorityAdvisor {
public:
  DevelopmentModePriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                                 SlotIndexes *const Indexes,
                                 MLModelRunner *Runner, Logger *Log)
      : MLPriorityAdvisor(MF, RA, Indexes, Runner), Log(Log) {}

private:
  unsigned getPriority(const LiveInterval &LI) const override;
  Logger *const Log;
};

class DevelopmentModePriorityAdvisorAnalysis final
    : public RegAllocPriorityAdvisorAnalysis {
public:
  DevelopmentModePriorityAdvisorAnalysis()
      : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Development) {}
  // support for isa<> and dyn_cast.
  static bool classof(const RegAllocPriorityAdvisorAnalysis *R) {
    return R->getAdvisorMode() == AdvisorMode::Development;
  }

  void logRewardIfNeeded(const MachineFunction &MF,
                         llvm::function_ref<float()> GetReward) override {
    if (!Log || !Log->hasAnyObservationForContext(MF.getName()))
      return;
    // The function pass manager would run all the function passes for a
    // function, so we assume the last context belongs to this function. If
    // this invariant ever changes, we can implement switching contexts at
    // that time. For now, it is an error if the context has changed.
    if (Log->currentContext() != MF.getName()) {
      MF.getFunction().getContext().emitError(
          "The training log context shouldn't have changed.");
    }
    if (Log->hasObservationInProgress())
      Log->logReward<float>(GetReward());
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<SlotIndexesWrapperPass>();
    RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU);
  }

  // Save all the logs (when requested).
  bool doInitialization(Module &M) override {
    LLVMContext &Ctx = M.getContext();
    if (ModelUnderTraining.empty() && TrainingLog.empty()) {
      Ctx.emitError("Regalloc development mode should be requested with at "
                    "least logging enabled and/or a training model");
      return false;
    }
    if (ModelUnderTraining.empty())
      Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures);
    else
      Runner = ModelUnderTrainingRunner::createAndEnsureValid(
          Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures);
    if (!Runner) {
      Ctx.emitError("Regalloc: could not set up the model runner");
      return false;
    }
    if (TrainingLog.empty())
      return false;
    std::error_code EC;
    auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC);
    if (EC) {
      M.getContext().emitError(EC.message() + ":" + TrainingLog);
      return false;
    }
    std::vector<TensorSpec> LFS = InputFeatures;
    if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get()))
      append_range(LFS, MUTR->extraOutputsForLoggingSpecs());
    // We always log the output; in particular, if we're not evaluating, we
    // don't have an output spec JSON file. That's why we handle the 'normal'
    // output separately.
    LFS.push_back(DecisionSpec);

    Log = std::make_unique<Logger>(std::move(OS), LFS, Reward,
                                   /*IncludeReward*/ true);
    return false;
  }

  std::unique_ptr<RegAllocPriorityAdvisor>
  getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
    if (!Runner)
      return nullptr;
    if (Log) {
      Log->switchContext(MF.getName());
    }

    return std::make_unique<DevelopmentModePriorityAdvisor>(
        MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get(),
        Log.get());
  }

  std::unique_ptr<MLModelRunner> Runner;
  std::unique_ptr<Logger> Log;
};
#endif // #ifdef LLVM_HAVE_TFLITE

} // namespace llvm

RegAllocPriorityAdvisorAnalysis *llvm::createReleaseModePriorityAdvisor() {
  return llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() ||
                 !InteractiveChannelBaseName.empty()
             ? new ReleaseModePriorityAdvisorAnalysis()
             : nullptr;
}

MLPriorityAdvisor::MLPriorityAdvisor(const MachineFunction &MF,
                                     const RAGreedy &RA,
                                     SlotIndexes *const Indexes,
                                     MLModelRunner *Runner)
    : RegAllocPriorityAdvisor(MF, RA, Indexes), DefaultAdvisor(MF, RA, Indexes),
      Runner(std::move(Runner)) {
  assert(this->Runner);
  Runner->switchContext(MF.getName());
}

float MLPriorityAdvisor::getPriorityImpl(const LiveInterval &LI) const {
  const unsigned Size = LI.getSize();
  LiveRangeStage Stage = RA.getExtraInfo().getStage(LI);

  // Tensor indices 0, 1, and 2 follow the feature order declared in
  // RA_PRIORITY_FEATURES_LIST: li_size, stage, weight.
  *Runner->getTensor<int64_t>(0) = static_cast<int64_t>(Size);
  *Runner->getTensor<int64_t>(1) = static_cast<int64_t>(Stage);
  *Runner->getTensor<float>(2) = static_cast<float>(LI.weight());

  return Runner->evaluate<float>();
}

unsigned MLPriorityAdvisor::getPriority(const LiveInterval &LI) const {
  return static_cast<unsigned>(getPriorityImpl(LI));
}

#ifdef LLVM_HAVE_TFLITE
RegAllocPriorityAdvisorAnalysis *llvm::createDevelopmentModePriorityAdvisor() {
  return new DevelopmentModePriorityAdvisorAnalysis();
}

unsigned
DevelopmentModePriorityAdvisor::getPriority(const LiveInterval &LI) const {
  double Prio = 0;

  if (isa<ModelUnderTrainingRunner>(getRunner())) {
    Prio = MLPriorityAdvisor::getPriorityImpl(LI);
  } else {
    Prio = getDefaultAdvisor().getPriority(LI);
  }

  if (TrainingLog.empty())
    return Prio;

  // TODO(mtrofin): when we support optional rewards, this can go away. In the
  // meantime, we log the "pretend" reward (0) for the previous observation
  // before starting a new one.
  if (Log->hasObservationInProgress())
    Log->logReward<float>(0.0);

  Log->startObservation();
  size_t CurrentFeature = 0;
  for (; CurrentFeature < InputFeatures.size(); ++CurrentFeature) {
    Log->logTensorValue(CurrentFeature,
                        reinterpret_cast<const char *>(
                            getRunner().getTensorUntyped(CurrentFeature)));
  }

  if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner())) {
    for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size();
         ++I, ++CurrentFeature)
      Log->logTensorValue(
          CurrentFeature,
          reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I)));
  }

  float Ret = static_cast<float>(Prio);
  Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret));
  Log->endObservation();

  return static_cast<unsigned>(Prio);
}

#endif // #ifdef LLVM_HAVE_TFLITE