xref: /llvm-project/bolt/lib/Passes/ThreeWayBranch.cpp (revision f92ab6af35343df1f82c3a85feb75d41b6be382e)
1 //===- bolt/Passes/ThreeWayBranch.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the ThreeWayBranch class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/ThreeWayBranch.h"
14 
15 #include <numeric>
16 
17 using namespace llvm;
18 
19 namespace llvm {
20 namespace bolt {
21 
22 bool ThreeWayBranch::shouldRunOnFunction(BinaryFunction &Function) {
23   BinaryContext &BC = Function.getBinaryContext();
24   BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout();
25   for (BinaryBasicBlock *BB : BlockLayout)
26     for (MCInst &Inst : *BB)
27       if (BC.MIB->isPacked(Inst))
28         return false;
29   return true;
30 }
31 
32 void ThreeWayBranch::runOnFunction(BinaryFunction &Function) {
33   BinaryContext &BC = Function.getBinaryContext();
34   MCContext *Ctx = BC.Ctx.get();
35   // New blocks will be added and layout will change,
36   // so make a copy here to iterate over the original layout
37   BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout();
38   for (BinaryBasicBlock *BB : BlockLayout) {
39     // The block must be hot
40     if (BB->getExecutionCount() == 0 ||
41         BB->getExecutionCount() == BinaryBasicBlock::COUNT_NO_PROFILE)
42       continue;
43     // with two successors
44     if (BB->succ_size() != 2)
45       continue;
46     // no jump table
47     if (BB->hasJumpTable())
48       continue;
49 
50     BinaryBasicBlock *FalseSucc = BB->getConditionalSuccessor(false);
51     BinaryBasicBlock *TrueSucc = BB->getConditionalSuccessor(true);
52 
53     // One of BB's successors must have only one instruction that is a
54     // conditional jump
55     if ((FalseSucc->succ_size() != 2 || FalseSucc->size() != 1) &&
56         (TrueSucc->succ_size() != 2 || TrueSucc->size() != 1))
57       continue;
58 
59     // SecondBranch has the second conditional jump
60     BinaryBasicBlock *SecondBranch = FalseSucc;
61     BinaryBasicBlock *FirstEndpoint = TrueSucc;
62     if (FalseSucc->succ_size() != 2) {
63       SecondBranch = TrueSucc;
64       FirstEndpoint = FalseSucc;
65     }
66 
67     BinaryBasicBlock *SecondEndpoint =
68         SecondBranch->getConditionalSuccessor(false);
69     BinaryBasicBlock *ThirdEndpoint =
70         SecondBranch->getConditionalSuccessor(true);
71 
72     // Make sure we can modify the jump in SecondBranch without disturbing any
73     // other paths
74     if (SecondBranch->pred_size() != 1)
75       continue;
76 
77     // Get Jump Instructions
78     MCInst *FirstJump = BB->getLastNonPseudoInstr();
79     MCInst *SecondJump = SecondBranch->getLastNonPseudoInstr();
80 
81     // Get condition codes
82     unsigned FirstCC = BC.MIB->getCondCode(*FirstJump);
83     if (SecondBranch != FalseSucc)
84       FirstCC = BC.MIB->getInvertedCondCode(FirstCC);
85     // ThirdCC = ThirdCond && !FirstCC = !(!ThirdCond ||
86     // !(!FirstCC)) = !(!ThirdCond || FirstCC)
87     unsigned ThirdCC =
88         BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr(
89             BC.MIB->getInvertedCondCode(BC.MIB->getCondCode(*SecondJump)),
90             FirstCC));
91     // SecondCC = !ThirdCond && !FirstCC = !(!(!ThirdCond) ||
92     // !(!FirstCC)) = !(ThirdCond || FirstCC)
93     unsigned SecondCC =
94         BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr(
95             BC.MIB->getCondCode(*SecondJump), FirstCC));
96 
97     if (!BC.MIB->isValidCondCode(FirstCC) ||
98         !BC.MIB->isValidCondCode(ThirdCC) || !BC.MIB->isValidCondCode(SecondCC))
99       continue;
100 
101     std::vector<std::pair<BinaryBasicBlock *, unsigned>> Blocks;
102     Blocks.push_back(std::make_pair(FirstEndpoint, FirstCC));
103     Blocks.push_back(std::make_pair(SecondEndpoint, SecondCC));
104     Blocks.push_back(std::make_pair(ThirdEndpoint, ThirdCC));
105 
106     std::sort(Blocks.begin(), Blocks.end(),
107               [&](const std::pair<BinaryBasicBlock *, unsigned> A,
108                   const std::pair<BinaryBasicBlock *, unsigned> B) {
109                 return A.first->getExecutionCount() <
110                        B.first->getExecutionCount();
111               });
112 
113     uint64_t NewSecondBranchCount = Blocks[1].first->getExecutionCount() +
114                                     Blocks[0].first->getExecutionCount();
115     bool SecondBranchBigger =
116         NewSecondBranchCount > Blocks[2].first->getExecutionCount();
117 
118     BB->removeAllSuccessors();
119     if (SecondBranchBigger) {
120       BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount());
121       BB->addSuccessor(SecondBranch, NewSecondBranchCount);
122     } else {
123       BB->addSuccessor(SecondBranch, NewSecondBranchCount);
124       BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount());
125     }
126 
127     // Remove and add so there is no duplicate successors
128     SecondBranch->removeAllSuccessors();
129     SecondBranch->addSuccessor(Blocks[0].first,
130                                Blocks[0].first->getExecutionCount());
131     SecondBranch->addSuccessor(Blocks[1].first,
132                                Blocks[1].first->getExecutionCount());
133 
134     SecondBranch->setExecutionCount(NewSecondBranchCount);
135 
136     // Replace the branch condition to fallthrough for the most common block
137     if (SecondBranchBigger)
138       BC.MIB->replaceBranchCondition(*FirstJump, Blocks[2].first->getLabel(),
139                                      Ctx, Blocks[2].second);
140     else
141       BC.MIB->replaceBranchCondition(
142           *FirstJump, SecondBranch->getLabel(), Ctx,
143           BC.MIB->getInvertedCondCode(Blocks[2].second));
144 
145     // Replace the branch condition to fallthrough for the second most common
146     // block
147     BC.MIB->replaceBranchCondition(*SecondJump, Blocks[0].first->getLabel(),
148                                    Ctx, Blocks[0].second);
149 
150     ++BranchesAltered;
151   }
152 }
153 
154 void ThreeWayBranch::runOnFunctions(BinaryContext &BC) {
155   for (auto &It : BC.getBinaryFunctions()) {
156     BinaryFunction &Function = It.second;
157     if (!shouldRunOnFunction(Function))
158       continue;
159     runOnFunction(Function);
160   }
161 
162   outs() << "BOLT-INFO: number of three way branches order changed: "
163          << BranchesAltered << "\n";
164 }
165 
166 } // end namespace bolt
167 } // end namespace llvm
168