xref: /llvm-project/bolt/lib/Passes/ThreeWayBranch.cpp (revision 2f09f445b2d6b3ef197aecd8d1e06d08140380f3)
1 //===- bolt/Passes/ThreeWayBranch.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the ThreeWayBranch class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/ThreeWayBranch.h"
14 
15 #include <numeric>
16 
17 using namespace llvm;
18 
19 namespace llvm {
20 namespace bolt {
21 
22 bool ThreeWayBranch::shouldRunOnFunction(BinaryFunction &Function) {
23   BinaryContext &BC = Function.getBinaryContext();
24   BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout();
25   for (BinaryBasicBlock *BB : BlockLayout) {
26     for (MCInst &Inst : *BB) {
27       if (BC.MIB->isPacked(Inst))
28         return false;
29     }
30   }
31   return true;
32 }
33 
34 void ThreeWayBranch::runOnFunction(BinaryFunction &Function) {
35   BinaryContext &BC = Function.getBinaryContext();
36   MCContext *Ctx = BC.Ctx.get();
37   // New blocks will be added and layout will change,
38   // so make a copy here to iterate over the original layout
39   BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout();
40   for (BinaryBasicBlock *BB : BlockLayout) {
41     // The block must be hot
42     if (BB->getExecutionCount() == 0 ||
43         BB->getExecutionCount() == BinaryBasicBlock::COUNT_NO_PROFILE)
44       continue;
45     // with two successors
46     if (BB->succ_size() != 2)
47       continue;
48     // no jump table
49     if (BB->hasJumpTable())
50       continue;
51 
52     BinaryBasicBlock *FalseSucc = BB->getConditionalSuccessor(false);
53     BinaryBasicBlock *TrueSucc = BB->getConditionalSuccessor(true);
54 
55     // One of BB's successors must have only one instruction that is a
56     // conditional jump
57     if ((FalseSucc->succ_size() != 2 || FalseSucc->size() != 1) &&
58         (TrueSucc->succ_size() != 2 || TrueSucc->size() != 1))
59       continue;
60 
61     // SecondBranch has the second conditional jump
62     BinaryBasicBlock *SecondBranch = FalseSucc;
63     BinaryBasicBlock *FirstEndpoint = TrueSucc;
64     if (FalseSucc->succ_size() != 2) {
65       SecondBranch = TrueSucc;
66       FirstEndpoint = FalseSucc;
67     }
68 
69     BinaryBasicBlock *SecondEndpoint =
70         SecondBranch->getConditionalSuccessor(false);
71     BinaryBasicBlock *ThirdEndpoint =
72         SecondBranch->getConditionalSuccessor(true);
73 
74     // Make sure we can modify the jump in SecondBranch without disturbing any
75     // other paths
76     if (SecondBranch->pred_size() != 1)
77       continue;
78 
79     // Get Jump Instructions
80     MCInst *FirstJump = BB->getLastNonPseudoInstr();
81     MCInst *SecondJump = SecondBranch->getLastNonPseudoInstr();
82 
83     // Get condition codes
84     unsigned FirstCC = BC.MIB->getCondCode(*FirstJump);
85     if (SecondBranch != FalseSucc)
86       FirstCC = BC.MIB->getInvertedCondCode(FirstCC);
87     // ThirdCC = ThirdCond && !FirstCC = !(!ThirdCond ||
88     // !(!FirstCC)) = !(!ThirdCond || FirstCC)
89     unsigned ThirdCC =
90         BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr(
91             BC.MIB->getInvertedCondCode(BC.MIB->getCondCode(*SecondJump)),
92             FirstCC));
93     // SecondCC = !ThirdCond && !FirstCC = !(!(!ThirdCond) ||
94     // !(!FirstCC)) = !(ThirdCond || FirstCC)
95     unsigned SecondCC =
96         BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr(
97             BC.MIB->getCondCode(*SecondJump), FirstCC));
98 
99     if (!BC.MIB->isValidCondCode(FirstCC) ||
100         !BC.MIB->isValidCondCode(ThirdCC) || !BC.MIB->isValidCondCode(SecondCC))
101       continue;
102 
103     std::vector<std::pair<BinaryBasicBlock *, unsigned>> Blocks;
104     Blocks.push_back(std::make_pair(FirstEndpoint, FirstCC));
105     Blocks.push_back(std::make_pair(SecondEndpoint, SecondCC));
106     Blocks.push_back(std::make_pair(ThirdEndpoint, ThirdCC));
107 
108     std::sort(Blocks.begin(), Blocks.end(),
109               [&](const std::pair<BinaryBasicBlock *, unsigned> A,
110                   const std::pair<BinaryBasicBlock *, unsigned> B) {
111                 return A.first->getExecutionCount() <
112                        B.first->getExecutionCount();
113               });
114 
115     uint64_t NewSecondBranchCount = Blocks[1].first->getExecutionCount() +
116                                     Blocks[0].first->getExecutionCount();
117     bool SecondBranchBigger =
118         NewSecondBranchCount > Blocks[2].first->getExecutionCount();
119 
120     BB->removeAllSuccessors();
121     if (SecondBranchBigger) {
122       BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount());
123       BB->addSuccessor(SecondBranch, NewSecondBranchCount);
124     } else {
125       BB->addSuccessor(SecondBranch, NewSecondBranchCount);
126       BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount());
127     }
128 
129     // Remove and add so there is no duplicate successors
130     SecondBranch->removeAllSuccessors();
131     SecondBranch->addSuccessor(Blocks[0].first,
132                                Blocks[0].first->getExecutionCount());
133     SecondBranch->addSuccessor(Blocks[1].first,
134                                Blocks[1].first->getExecutionCount());
135 
136     SecondBranch->setExecutionCount(NewSecondBranchCount);
137 
138     // Replace the branch condition to fallthrough for the most common block
139     if (SecondBranchBigger) {
140       BC.MIB->replaceBranchCondition(*FirstJump, Blocks[2].first->getLabel(),
141                                      Ctx, Blocks[2].second);
142     } else {
143       BC.MIB->replaceBranchCondition(
144           *FirstJump, SecondBranch->getLabel(), Ctx,
145           BC.MIB->getInvertedCondCode(Blocks[2].second));
146     }
147 
148     // Replace the branch condition to fallthrough for the second most common
149     // block
150     BC.MIB->replaceBranchCondition(*SecondJump, Blocks[0].first->getLabel(),
151                                    Ctx, Blocks[0].second);
152 
153     ++BranchesAltered;
154   }
155 }
156 
157 void ThreeWayBranch::runOnFunctions(BinaryContext &BC) {
158   for (auto &It : BC.getBinaryFunctions()) {
159     BinaryFunction &Function = It.second;
160     if (!shouldRunOnFunction(Function))
161       continue;
162     runOnFunction(Function);
163   }
164 
165   outs() << "BOLT-INFO: number of three way branches order changed: "
166          << BranchesAltered << "\n";
167 }
168 
169 } // end namespace bolt
170 } // end namespace llvm
171