xref: /llvm-project/llvm/lib/CodeGen/RegAllocScore.cpp (revision 735ab61ac828bd61398e6847d60e308fdf2b54ec)
1 //===- RegAllocScore.cpp - evaluate regalloc policy quality ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// Calculate a measure of the register allocation policy quality. This is used
9 /// to construct a reward for the training of the ML-driven allocation policy.
/// Currently, the score is the sum of the number of loads, stores, copies, and
/// remat instructions, each weighted by the frequency of its machine basic
/// block and scaled by a per-category relative weight.
13 //===----------------------------------------------------------------------===//
14 
15 #include "RegAllocScore.h"
16 #include "llvm/CodeGen/MachineBasicBlock.h"
17 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
18 #include "llvm/CodeGen/MachineFunction.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/TargetInstrInfo.h"
21 #include "llvm/CodeGen/TargetSubtargetInfo.h"
22 #include "llvm/MC/MCInstrDesc.h"
23 #include "llvm/Support/CommandLine.h"
24 
25 using namespace llvm;
26 cl::opt<double> CopyWeight("regalloc-copy-weight", cl::init(0.2), cl::Hidden);
27 cl::opt<double> LoadWeight("regalloc-load-weight", cl::init(4.0), cl::Hidden);
28 cl::opt<double> StoreWeight("regalloc-store-weight", cl::init(1.0), cl::Hidden);
29 cl::opt<double> CheapRematWeight("regalloc-cheap-remat-weight", cl::init(0.2),
30                                  cl::Hidden);
31 cl::opt<double> ExpensiveRematWeight("regalloc-expensive-remat-weight",
32                                      cl::init(1.0), cl::Hidden);
33 #define DEBUG_TYPE "regalloc-score"
34 
35 RegAllocScore &RegAllocScore::operator+=(const RegAllocScore &Other) {
36   CopyCounts += Other.copyCounts();
37   LoadCounts += Other.loadCounts();
38   StoreCounts += Other.storeCounts();
39   LoadStoreCounts += Other.loadStoreCounts();
40   CheapRematCounts += Other.cheapRematCounts();
41   ExpensiveRematCounts += Other.expensiveRematCounts();
42   return *this;
43 }
44 
45 bool RegAllocScore::operator==(const RegAllocScore &Other) const {
46   return copyCounts() == Other.copyCounts() &&
47          loadCounts() == Other.loadCounts() &&
48          storeCounts() == Other.storeCounts() &&
49          loadStoreCounts() == Other.loadStoreCounts() &&
50          cheapRematCounts() == Other.cheapRematCounts() &&
51          expensiveRematCounts() == Other.expensiveRematCounts();
52 }
53 
54 bool RegAllocScore::operator!=(const RegAllocScore &Other) const {
55   return !(*this == Other);
56 }
57 
58 double RegAllocScore::getScore() const {
59   double Ret = 0.0;
60   Ret += CopyWeight * copyCounts();
61   Ret += LoadWeight * loadCounts();
62   Ret += StoreWeight * storeCounts();
63   Ret += (LoadWeight + StoreWeight) * loadStoreCounts();
64   Ret += CheapRematWeight * cheapRematCounts();
65   Ret += ExpensiveRematWeight * expensiveRematCounts();
66 
67   return Ret;
68 }
69 
70 RegAllocScore
71 llvm::calculateRegAllocScore(const MachineFunction &MF,
72                              const MachineBlockFrequencyInfo &MBFI) {
73   return calculateRegAllocScore(
74       MF,
75       [&](const MachineBasicBlock &MBB) {
76         return MBFI.getBlockFreqRelativeToEntryBlock(&MBB);
77       },
78       [&](const MachineInstr &MI) {
79         return MF.getSubtarget().getInstrInfo()->isTriviallyReMaterializable(
80             MI);
81       });
82 }
83 
84 RegAllocScore llvm::calculateRegAllocScore(
85     const MachineFunction &MF,
86     llvm::function_ref<double(const MachineBasicBlock &)> GetBBFreq,
87     llvm::function_ref<bool(const MachineInstr &)>
88         IsTriviallyRematerializable) {
89   RegAllocScore Total;
90 
91   for (const MachineBasicBlock &MBB : MF) {
92     double BlockFreqRelativeToEntrypoint = GetBBFreq(MBB);
93     RegAllocScore MBBScore;
94 
95     for (const MachineInstr &MI : MBB) {
96       if (MI.isDebugInstr() || MI.isKill() || MI.isInlineAsm()) {
97         continue;
98       }
99       if (MI.isCopy()) {
100         MBBScore.onCopy(BlockFreqRelativeToEntrypoint);
101       } else if (IsTriviallyRematerializable(MI)) {
102         if (MI.getDesc().isAsCheapAsAMove()) {
103           MBBScore.onCheapRemat(BlockFreqRelativeToEntrypoint);
104         } else {
105           MBBScore.onExpensiveRemat(BlockFreqRelativeToEntrypoint);
106         }
107       } else if (MI.mayLoad() && MI.mayStore()) {
108         MBBScore.onLoadStore(BlockFreqRelativeToEntrypoint);
109       } else if (MI.mayLoad()) {
110         MBBScore.onLoad(BlockFreqRelativeToEntrypoint);
111       } else if (MI.mayStore()) {
112         MBBScore.onStore(BlockFreqRelativeToEntrypoint);
113       }
114     }
115     Total += MBBScore;
116   }
117   return Total;
118 }
119