xref: /llvm-project/bolt/lib/Passes/RetpolineInsertion.cpp (revision 0a5edb4de408ae0405f85c3e4c6da5233f185f63)
1 //===- bolt/Passes/RetpolineInsertion.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements RetpolineInsertion class, which replaces indirect
10 // branches (calls and jumps) with calls to retpolines to protect against branch
11 // target injection attacks.
// A unique retpoline is created for each register holding the address of the
// callee. If the callee address is in memory, %r11 is used (if available) to
// hold the address of the callee before calling the retpoline; otherwise, an
// address-pattern-specific retpoline is called, and the callee address is
// loaded inside the retpoline.
// The r11-availability option controls when %r11 is assumed to be available;
// by default, %r11 is assumed not to be available.
// An lfence instruction is added to the body of the speculation loop by
// default; this can be controlled with the retpoline-lfence option.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "bolt/Passes/RetpolineInsertion.h"
25 #include "llvm/MC/MCInstPrinter.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 #define DEBUG_TYPE "bolt-retpoline"
29 
30 using namespace llvm;
31 using namespace bolt;
32 namespace opts {
33 
34 extern cl::OptionCategory BoltCategory;
35 
36 llvm::cl::opt<bool> InsertRetpolines("insert-retpolines",
37                                      cl::desc("run retpoline insertion pass"),
38                                      cl::cat(BoltCategory));
39 
40 llvm::cl::opt<bool>
41 RetpolineLfence("retpoline-lfence",
42   cl::desc("determine if lfence instruction should exist in the retpoline"),
43   cl::init(true),
44   cl::ZeroOrMore,
45   cl::Hidden,
46   cl::cat(BoltCategory));
47 
48 cl::opt<RetpolineInsertion::AvailabilityOptions> R11Availability(
49     "r11-availability",
50     cl::desc("determine the availability of r11 before indirect branches"),
51     cl::init(RetpolineInsertion::AvailabilityOptions::NEVER),
52     cl::values(clEnumValN(RetpolineInsertion::AvailabilityOptions::NEVER,
53                           "never", "r11 not available"),
54                clEnumValN(RetpolineInsertion::AvailabilityOptions::ALWAYS,
55                           "always", "r11 available before calls and jumps"),
56                clEnumValN(RetpolineInsertion::AvailabilityOptions::ABI, "abi",
57                           "r11 available before calls but not before jumps")),
58     cl::ZeroOrMore, cl::cat(BoltCategory));
59 
60 } // namespace opts
61 
62 namespace llvm {
63 namespace bolt {
64 
65 // Retpoline function structure:
66 // BB0: call BB2
67 // BB1: pause
68 //      lfence
69 //      jmp BB1
70 // BB2: mov %reg, (%rsp)
71 //      ret
72 // or
73 // BB2: push %r11
74 //      mov Address, %r11
75 //      mov %r11, 8(%rsp)
76 //      pop %r11
77 //      ret
78 BinaryFunction *createNewRetpoline(BinaryContext &BC,
79                                    const std::string &RetpolineTag,
80                                    const IndirectBranchInfo &BrInfo,
81                                    bool R11Available) {
82   auto &MIB = *BC.MIB;
83   MCContext &Ctx = *BC.Ctx.get();
84   LLVM_DEBUG(dbgs() << "BOLT-DEBUG: Creating a new retpoline function["
85                     << RetpolineTag << "]\n");
86 
87   BinaryFunction *NewRetpoline =
88       BC.createInjectedBinaryFunction(RetpolineTag, true);
89   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBlocks(3);
90   for (int I = 0; I < 3; I++) {
91     MCSymbol *Symbol =
92         Ctx.createNamedTempSymbol(Twine(RetpolineTag + "_BB" + to_string(I)));
93     NewBlocks[I] = NewRetpoline->createBasicBlock(Symbol);
94     NewBlocks[I].get()->setCFIState(0);
95   }
96 
97   BinaryBasicBlock &BB0 = *NewBlocks[0].get();
98   BinaryBasicBlock &BB1 = *NewBlocks[1].get();
99   BinaryBasicBlock &BB2 = *NewBlocks[2].get();
100 
101   BB0.addSuccessor(&BB2, 0, 0);
102   BB1.addSuccessor(&BB1, 0, 0);
103 
104   // Build BB0
105   MCInst DirectCall;
106   MIB.createDirectCall(DirectCall, BB2.getLabel(), &Ctx, /*IsTailCall*/ false);
107   BB0.addInstruction(DirectCall);
108 
109   // Build BB1
110   MCInst Pause;
111   MIB.createPause(Pause);
112   BB1.addInstruction(Pause);
113 
114   if (opts::RetpolineLfence) {
115     MCInst Lfence;
116     MIB.createLfence(Lfence);
117     BB1.addInstruction(Lfence);
118   }
119 
120   InstructionListType Seq;
121   MIB.createShortJmp(Seq, BB1.getLabel(), &Ctx);
122   BB1.addInstructions(Seq.begin(), Seq.end());
123 
124   // Build BB2
125   if (BrInfo.isMem()) {
126     if (R11Available) {
127       MCInst StoreToStack;
128       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
129                             MIB.getX86R11(), 8);
130       BB2.addInstruction(StoreToStack);
131     } else {
132       MCInst PushR11;
133       MIB.createPushRegister(PushR11, MIB.getX86R11(), 8);
134       BB2.addInstruction(PushR11);
135 
136       MCInst LoadCalleeAddrs;
137       const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
138       MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleImm,
139                      MemRef.IndexRegNum, MemRef.DispImm, MemRef.DispExpr,
140                      MemRef.SegRegNum, MIB.getX86R11(), 8);
141 
142       BB2.addInstruction(LoadCalleeAddrs);
143 
144       MCInst StoreToStack;
145       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 8,
146                             MIB.getX86R11(), 8);
147       BB2.addInstruction(StoreToStack);
148 
149       MCInst PopR11;
150       MIB.createPopRegister(PopR11, MIB.getX86R11(), 8);
151       BB2.addInstruction(PopR11);
152     }
153   } else if (BrInfo.isReg()) {
154     MCInst StoreToStack;
155     MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
156                           BrInfo.BranchReg, 8);
157     BB2.addInstruction(StoreToStack);
158   } else {
159     llvm_unreachable("not expected");
160   }
161 
162   // return
163   MCInst Return;
164   MIB.createReturn(Return);
165   BB2.addInstruction(Return);
166   NewRetpoline->insertBasicBlocks(nullptr, std::move(NewBlocks),
167                                   /* UpdateLayout */ true,
168                                   /* UpdateCFIState */ false);
169 
170   NewRetpoline->updateState(BinaryFunction::State::CFG_Finalized);
171   return NewRetpoline;
172 }
173 
174 std::string createRetpolineFunctionTag(BinaryContext &BC,
175                                        const IndirectBranchInfo &BrInfo,
176                                        bool R11Available) {
177   std::string Tag;
178   llvm::raw_string_ostream TagOS(Tag);
179   TagOS << "__retpoline_";
180 
181   if (BrInfo.isReg()) {
182     BC.InstPrinter->printRegName(TagOS, BrInfo.BranchReg);
183     TagOS << "_";
184     return Tag;
185   }
186 
187   // Memory Branch
188   if (R11Available)
189     return "__retpoline_r11";
190 
191   const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
192 
193   TagOS << "mem_";
194 
195   if (MemRef.BaseRegNum != BC.MIB->getNoRegister())
196     BC.InstPrinter->printRegName(TagOS, MemRef.BaseRegNum);
197 
198   TagOS << "+";
199   if (MemRef.DispExpr)
200     MemRef.DispExpr->print(TagOS, BC.AsmInfo.get());
201   else
202     TagOS << MemRef.DispImm;
203 
204   if (MemRef.IndexRegNum != BC.MIB->getNoRegister()) {
205     TagOS << "+" << MemRef.ScaleImm << "*";
206     BC.InstPrinter->printRegName(TagOS, MemRef.IndexRegNum);
207   }
208 
209   if (MemRef.SegRegNum != BC.MIB->getNoRegister()) {
210     TagOS << "_seg_";
211     BC.InstPrinter->printRegName(TagOS, MemRef.SegRegNum);
212   }
213 
214   return Tag;
215 }
216 
217 BinaryFunction *RetpolineInsertion::getOrCreateRetpoline(
218     BinaryContext &BC, const IndirectBranchInfo &BrInfo, bool R11Available) {
219   const std::string RetpolineTag =
220       createRetpolineFunctionTag(BC, BrInfo, R11Available);
221 
222   if (CreatedRetpolines.count(RetpolineTag))
223     return CreatedRetpolines[RetpolineTag];
224 
225   return CreatedRetpolines[RetpolineTag] =
226              createNewRetpoline(BC, RetpolineTag, BrInfo, R11Available);
227 }
228 
229 void createBranchReplacement(BinaryContext &BC,
230                              const IndirectBranchInfo &BrInfo,
231                              bool R11Available,
232                              InstructionListType &Replacement,
233                              const MCSymbol *RetpolineSymbol) {
234   auto &MIB = *BC.MIB;
235   // Load the branch address in r11 if available
236   if (BrInfo.isMem() && R11Available) {
237     const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
238     MCInst LoadCalleeAddrs;
239     MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleImm,
240                    MemRef.IndexRegNum, MemRef.DispImm, MemRef.DispExpr,
241                    MemRef.SegRegNum, MIB.getX86R11(), 8);
242     Replacement.push_back(LoadCalleeAddrs);
243   }
244 
245   // Call the retpoline
246   MCInst RetpolineCall;
247   MIB.createDirectCall(RetpolineCall, RetpolineSymbol, BC.Ctx.get(),
248                        BrInfo.isJump() || BrInfo.isTailCall());
249 
250   Replacement.push_back(RetpolineCall);
251 }
252 
253 IndirectBranchInfo::IndirectBranchInfo(MCInst &Inst, MCPlusBuilder &MIB) {
254   IsCall = MIB.isCall(Inst);
255   IsTailCall = MIB.isTailCall(Inst);
256 
257   if (MIB.isBranchOnMem(Inst)) {
258     IsMem = true;
259     std::optional<MCPlusBuilder::X86MemOperand> MO =
260         MIB.evaluateX86MemoryOperand(Inst);
261     if (!MO)
262       llvm_unreachable("not expected");
263     Memory = MO.value();
264   } else if (MIB.isBranchOnReg(Inst)) {
265     assert(MCPlus::getNumPrimeOperands(Inst) == 1 && "expect 1 operand");
266     BranchReg = Inst.getOperand(0).getReg();
267   } else {
268     llvm_unreachable("unexpected instruction");
269   }
270 }
271 
272 Error RetpolineInsertion::runOnFunctions(BinaryContext &BC) {
273   if (!opts::InsertRetpolines)
274     return Error::success();
275 
276   assert(BC.isX86() &&
277          "retpoline insertion not supported for target architecture");
278 
279   assert(BC.HasRelocations && "retpoline mode not supported in non-reloc");
280 
281   auto &MIB = *BC.MIB;
282   uint32_t RetpolinedBranches = 0;
283   for (auto &It : BC.getBinaryFunctions()) {
284     BinaryFunction &Function = It.second;
285     for (BinaryBasicBlock &BB : Function) {
286       for (auto It = BB.begin(); It != BB.end(); ++It) {
287         MCInst &Inst = *It;
288 
289         if (!MIB.isIndirectCall(Inst) && !MIB.isIndirectBranch(Inst))
290           continue;
291 
292         IndirectBranchInfo BrInfo(Inst, MIB);
293         bool R11Available = false;
294         BinaryFunction *TargetRetpoline;
295         InstructionListType Replacement;
296 
297         // Determine if r11 is available before this instruction
298         if (BrInfo.isMem()) {
299           if (MIB.hasAnnotation(Inst, "PLTCall"))
300             R11Available = true;
301           else if (opts::R11Availability == AvailabilityOptions::ALWAYS)
302             R11Available = true;
303           else if (opts::R11Availability == AvailabilityOptions::ABI)
304             R11Available = BrInfo.isCall();
305         }
306 
307         // If the instruction addressing pattern uses rsp and the retpoline
308         // loads the callee address then displacement needs to be updated
309         if (BrInfo.isMem() && !R11Available) {
310           IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
311           int Addend = (BrInfo.isJump() || BrInfo.isTailCall()) ? 8 : 16;
312           if (MemRef.BaseRegNum == MIB.getStackPointer())
313             MemRef.DispImm += Addend;
314           if (MemRef.IndexRegNum == MIB.getStackPointer())
315             MemRef.DispImm += Addend * MemRef.ScaleImm;
316         }
317 
318         TargetRetpoline = getOrCreateRetpoline(BC, BrInfo, R11Available);
319 
320         createBranchReplacement(BC, BrInfo, R11Available, Replacement,
321                                 TargetRetpoline->getSymbol());
322 
323         It = BB.replaceInstruction(It, Replacement.begin(), Replacement.end());
324         RetpolinedBranches++;
325       }
326     }
327   }
328   BC.outs() << "BOLT-INFO: The number of created retpoline functions is : "
329             << CreatedRetpolines.size()
330             << "\nBOLT-INFO: The number of retpolined branches is : "
331             << RetpolinedBranches << "\n";
332   return Error::success();
333 }
334 
335 } // namespace bolt
336 } // namespace llvm
337