xref: /llvm-project/bolt/lib/Passes/ThreeWayBranch.cpp (revision d55dfeaf32e8a88f0c6e7240fb2a8a1b57d89380)
1 //===- bolt/Passes/ThreeWayBranch.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the ThreeWayBranch class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/ThreeWayBranch.h"
14 
15 using namespace llvm;
16 
17 namespace llvm {
18 namespace bolt {
19 
20 bool ThreeWayBranch::shouldRunOnFunction(BinaryFunction &Function) {
21   BinaryContext &BC = Function.getBinaryContext();
22   for (const BinaryBasicBlock &BB : Function)
23     for (const MCInst &Inst : BB)
24       if (BC.MIB->isPacked(Inst))
25         return false;
26   return true;
27 }
28 
29 void ThreeWayBranch::runOnFunction(BinaryFunction &Function) {
30   BinaryContext &BC = Function.getBinaryContext();
31   MCContext *Ctx = BC.Ctx.get();
32   // New blocks will be added and layout will change,
33   // so make a copy here to iterate over the original layout
34   BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout();
35   for (BinaryBasicBlock *BB : BlockLayout) {
36     // The block must be hot
37     if (BB->getExecutionCount() == 0 ||
38         BB->getExecutionCount() == BinaryBasicBlock::COUNT_NO_PROFILE)
39       continue;
40     // with two successors
41     if (BB->succ_size() != 2)
42       continue;
43     // no jump table
44     if (BB->hasJumpTable())
45       continue;
46 
47     BinaryBasicBlock *FalseSucc = BB->getConditionalSuccessor(false);
48     BinaryBasicBlock *TrueSucc = BB->getConditionalSuccessor(true);
49 
50     // One of BB's successors must have only one instruction that is a
51     // conditional jump
52     if ((FalseSucc->succ_size() != 2 || FalseSucc->size() != 1) &&
53         (TrueSucc->succ_size() != 2 || TrueSucc->size() != 1))
54       continue;
55 
56     // SecondBranch has the second conditional jump
57     BinaryBasicBlock *SecondBranch = FalseSucc;
58     BinaryBasicBlock *FirstEndpoint = TrueSucc;
59     if (FalseSucc->succ_size() != 2) {
60       SecondBranch = TrueSucc;
61       FirstEndpoint = FalseSucc;
62     }
63 
64     BinaryBasicBlock *SecondEndpoint =
65         SecondBranch->getConditionalSuccessor(false);
66     BinaryBasicBlock *ThirdEndpoint =
67         SecondBranch->getConditionalSuccessor(true);
68 
69     // Make sure we can modify the jump in SecondBranch without disturbing any
70     // other paths
71     if (SecondBranch->pred_size() != 1)
72       continue;
73 
74     // Get Jump Instructions
75     MCInst *FirstJump = BB->getLastNonPseudoInstr();
76     MCInst *SecondJump = SecondBranch->getLastNonPseudoInstr();
77 
78     // Get condition codes
79     unsigned FirstCC = BC.MIB->getCondCode(*FirstJump);
80     if (SecondBranch != FalseSucc)
81       FirstCC = BC.MIB->getInvertedCondCode(FirstCC);
82     // ThirdCC = ThirdCond && !FirstCC = !(!ThirdCond ||
83     // !(!FirstCC)) = !(!ThirdCond || FirstCC)
84     unsigned ThirdCC =
85         BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr(
86             BC.MIB->getInvertedCondCode(BC.MIB->getCondCode(*SecondJump)),
87             FirstCC));
88     // SecondCC = !ThirdCond && !FirstCC = !(!(!ThirdCond) ||
89     // !(!FirstCC)) = !(ThirdCond || FirstCC)
90     unsigned SecondCC =
91         BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr(
92             BC.MIB->getCondCode(*SecondJump), FirstCC));
93 
94     if (!BC.MIB->isValidCondCode(FirstCC) ||
95         !BC.MIB->isValidCondCode(ThirdCC) || !BC.MIB->isValidCondCode(SecondCC))
96       continue;
97 
98     std::vector<std::pair<BinaryBasicBlock *, unsigned>> Blocks;
99     Blocks.push_back(std::make_pair(FirstEndpoint, FirstCC));
100     Blocks.push_back(std::make_pair(SecondEndpoint, SecondCC));
101     Blocks.push_back(std::make_pair(ThirdEndpoint, ThirdCC));
102 
103     llvm::sort(Blocks, [&](const std::pair<BinaryBasicBlock *, unsigned> A,
104                            const std::pair<BinaryBasicBlock *, unsigned> B) {
105       return A.first->getExecutionCount() < B.first->getExecutionCount();
106     });
107 
108     uint64_t NewSecondBranchCount = Blocks[1].first->getExecutionCount() +
109                                     Blocks[0].first->getExecutionCount();
110     bool SecondBranchBigger =
111         NewSecondBranchCount > Blocks[2].first->getExecutionCount();
112 
113     BB->removeAllSuccessors();
114     if (SecondBranchBigger) {
115       BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount());
116       BB->addSuccessor(SecondBranch, NewSecondBranchCount);
117     } else {
118       BB->addSuccessor(SecondBranch, NewSecondBranchCount);
119       BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount());
120     }
121 
122     // Remove and add so there is no duplicate successors
123     SecondBranch->removeAllSuccessors();
124     SecondBranch->addSuccessor(Blocks[0].first,
125                                Blocks[0].first->getExecutionCount());
126     SecondBranch->addSuccessor(Blocks[1].first,
127                                Blocks[1].first->getExecutionCount());
128 
129     SecondBranch->setExecutionCount(NewSecondBranchCount);
130 
131     // Replace the branch condition to fallthrough for the most common block
132     if (SecondBranchBigger)
133       BC.MIB->replaceBranchCondition(*FirstJump, Blocks[2].first->getLabel(),
134                                      Ctx, Blocks[2].second);
135     else
136       BC.MIB->replaceBranchCondition(
137           *FirstJump, SecondBranch->getLabel(), Ctx,
138           BC.MIB->getInvertedCondCode(Blocks[2].second));
139 
140     // Replace the branch condition to fallthrough for the second most common
141     // block
142     BC.MIB->replaceBranchCondition(*SecondJump, Blocks[0].first->getLabel(),
143                                    Ctx, Blocks[0].second);
144 
145     ++BranchesAltered;
146   }
147 }
148 
149 void ThreeWayBranch::runOnFunctions(BinaryContext &BC) {
150   for (auto &It : BC.getBinaryFunctions()) {
151     BinaryFunction &Function = It.second;
152     if (!shouldRunOnFunction(Function))
153       continue;
154     runOnFunction(Function);
155   }
156 
157   outs() << "BOLT-INFO: number of three way branches order changed: "
158          << BranchesAltered << "\n";
159 }
160 
161 } // end namespace bolt
162 } // end namespace llvm
163