xref: /openbsd-src/gnu/llvm/llvm/lib/CodeGen/ReturnProtectorLowering.cpp (revision d3a95e33d6d8bdc6ac076f4f3cceb472e8e04206)
1 //===- ReturnProtectorLowering.cpp - ---------------------------------------==//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Implements common routines for return protector support.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/ReturnProtectorLowering.h"
15 #include "llvm/ADT/SmallSet.h"
16 #include "llvm/CodeGen/MachineFrameInfo.h"
17 #include "llvm/CodeGen/MachineFunction.h"
18 #include "llvm/CodeGen/MachineRegisterInfo.h"
19 #include "llvm/CodeGen/TargetFrameLowering.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/MC/MCRegisterInfo.h"
24 #include "llvm/Target/TargetMachine.h"
25 #include "llvm/Target/TargetOptions.h"
26 
27 using namespace llvm;
28 
markUsedRegsInSuccessors(MachineBasicBlock & MBB,SmallSet<unsigned,16> & Used,SmallSet<int,24> & Visited)29 static void markUsedRegsInSuccessors(MachineBasicBlock &MBB,
30                                      SmallSet<unsigned, 16> &Used,
31                                      SmallSet<int, 24> &Visited) {
32   int BBNum = MBB.getNumber();
33   if (Visited.count(BBNum))
34     return;
35 
36   // Mark all the registers used
37   for (auto &MBBI : MBB.instrs()) {
38     for (auto &MBBIOp : MBBI.operands()) {
39       if (MBBIOp.isReg())
40         Used.insert(MBBIOp.getReg());
41     }
42   }
43 
44   // Mark this MBB as visited
45   Visited.insert(BBNum);
46   // Recurse over all successors
47   for (auto &SuccMBB : MBB.successors())
48     markUsedRegsInSuccessors(*SuccMBB, Used, Visited);
49 }
50 
containsProtectableData(Type * Ty)51 static bool containsProtectableData(Type *Ty) {
52   if (!Ty)
53     return false;
54 
55   if (ArrayType *AT = dyn_cast<ArrayType>(Ty))
56     return true;
57 
58   if (StructType *ST = dyn_cast<StructType>(Ty)) {
59     for (StructType::element_iterator I = ST->element_begin(),
60                                       E = ST->element_end();
61          I != E; ++I) {
62       if (containsProtectableData(*I))
63         return true;
64     }
65   }
66   return false;
67 }
68 
69 // Mostly the same as StackProtector::HasAddressTaken
hasAddressTaken(const Instruction * AI,SmallPtrSet<const PHINode *,16> & visitedPHI)70 static bool hasAddressTaken(const Instruction *AI,
71                             SmallPtrSet<const PHINode *, 16> &visitedPHI) {
72   for (const User *U : AI->users()) {
73     const auto *I = cast<Instruction>(U);
74     switch (I->getOpcode()) {
75     case Instruction::Store:
76       if (AI == cast<StoreInst>(I)->getValueOperand())
77         return true;
78       break;
79     case Instruction::AtomicCmpXchg:
80       if (AI == cast<AtomicCmpXchgInst>(I)->getNewValOperand())
81         return true;
82       break;
83     case Instruction::PtrToInt:
84       if (AI == cast<PtrToIntInst>(I)->getOperand(0))
85         return true;
86       break;
87     case Instruction::BitCast:
88     case Instruction::GetElementPtr:
89     case Instruction::Select:
90     case Instruction::AddrSpaceCast:
91       if (hasAddressTaken(I, visitedPHI))
92         return true;
93       break;
94     case Instruction::PHI: {
95       const auto *PN = cast<PHINode>(I);
96       if (visitedPHI.insert(PN).second)
97         if (hasAddressTaken(PN, visitedPHI))
98           return true;
99       break;
100     }
101     case Instruction::Load:
102     case Instruction::AtomicRMW:
103     case Instruction::Ret:
104       return false;
105       break;
106     default:
107       // Conservatively return true for any instruction that takes an address
108       // operand, but is not handled above.
109       return true;
110     }
111   }
112   return false;
113 }
114 
115 /// setupReturnProtector - Checks the function for ROP friendly return
116 /// instructions and sets ReturnProtectorNeeded if found.
setupReturnProtector(MachineFunction & MF) const117 void ReturnProtectorLowering::setupReturnProtector(MachineFunction &MF) const {
118   if (MF.getFunction().hasFnAttribute("ret-protector")) {
119     for (auto &MBB : MF) {
120       for (auto &T : MBB.terminators()) {
121         if (opcodeIsReturn(T.getOpcode())) {
122           MF.getFrameInfo().setReturnProtectorNeeded(true);
123           return;
124         }
125       }
126     }
127   }
128 }
129 
130 /// saveReturnProtectorRegister - Allows the target to save the
131 /// ReturnProtectorRegister in the CalleeSavedInfo vector if needed.
saveReturnProtectorRegister(MachineFunction & MF,std::vector<CalleeSavedInfo> & CSI) const132 void ReturnProtectorLowering::saveReturnProtectorRegister(
133     MachineFunction &MF, std::vector<CalleeSavedInfo> &CSI) const {
134   const MachineFrameInfo &MFI = MF.getFrameInfo();
135   if (!MFI.getReturnProtectorNeeded())
136     return;
137 
138   if (!MFI.hasReturnProtectorRegister())
139     llvm_unreachable("Saving unset return protector register");
140 
141   unsigned Reg = MFI.getReturnProtectorRegister();
142   if (MFI.getReturnProtectorNeedsStore())
143     CSI.push_back(CalleeSavedInfo(Reg));
144   else {
145     for (auto &MBB : MF) {
146       if (!MBB.isLiveIn(Reg))
147         MBB.addLiveIn(Reg);
148     }
149   }
150 }
151 
152 /// determineReturnProtectorTempRegister - Find a register that can be used
153 /// during function prologue / epilogue to store the return protector cookie.
154 /// Returns false if a register is needed but could not be found,
155 /// otherwise returns true.
determineReturnProtectorRegister(MachineFunction & MF,const SmallVector<MachineBasicBlock *,4> & SaveBlocks,const SmallVector<MachineBasicBlock *,4> & RestoreBlocks) const156 bool ReturnProtectorLowering::determineReturnProtectorRegister(
157     MachineFunction &MF, const SmallVector<MachineBasicBlock *, 4> &SaveBlocks,
158     const SmallVector<MachineBasicBlock *, 4> &RestoreBlocks) const {
159   MachineFrameInfo &MFI = MF.getFrameInfo();
160   if (!MFI.getReturnProtectorNeeded())
161     return true;
162 
163   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
164 
165   std::vector<unsigned> TempRegs;
166   fillTempRegisters(MF, TempRegs);
167 
168   // For leaf functions, try to find a free register that is available
169   // in every BB, so we do not need to store it in the frame at all.
170   // We walk the entire function here because MFI.hasCalls() is unreliable.
171   bool hasCalls = false;
172   for (auto &MBB : MF) {
173     for (auto &MI : MBB) {
174       if (MI.isCall() && !MI.isReturn()) {
175         hasCalls = true;
176         break;
177       }
178     }
179     if (hasCalls)
180       break;
181   }
182 
183   // If the return address is always on the stack, then we
184   // want to try to keep the return protector cookie unspilled.
185   // This prevents a single stack smash from corrupting both the
186   // return protector cookie and the return address.
187   llvm::Triple::ArchType arch = MF.getTarget().getTargetTriple().getArch();
188   bool returnAddrOnStack = arch == llvm::Triple::ArchType::x86
189                         || arch == llvm::Triple::ArchType::x86_64;
190 
191   // For architectures which do not spill a return address
192   // to the stack by default, it is possible that in a leaf
193   // function that neither the return address or the retguard cookie
194   // will be spilled, and stack corruption may be missed.
195   // Here, we check leaf functions on these kinds of architectures
196   // to see if they have any variable sized local allocations,
197   // array type allocations, allocations which contain array
198   // types, or elements that have their address taken. If any of
199   // these conditions are met, then we skip leaf function
200   // optimization and spill the retguard cookie to the stack.
201   bool hasLocals = MFI.hasVarSizedObjects();
202   if (!hasCalls && !hasLocals && !returnAddrOnStack) {
203     for (const BasicBlock &BB : MF.getFunction()) {
204       for (const Instruction &I : BB) {
205         if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
206           // Check for array allocations
207           Type *Ty = AI->getAllocatedType();
208           if (AI->isArrayAllocation() || containsProtectableData(Ty)) {
209             hasLocals = true;
210             break;
211           }
212           // Check for address taken
213           SmallPtrSet<const PHINode *, 16> visitedPHIs;
214           if (hasAddressTaken(AI, visitedPHIs)) {
215             hasLocals = true;
216             break;
217           }
218         }
219       }
220       if (hasLocals)
221         break;
222     }
223   }
224 
225   bool tryLeafOptimize = !hasCalls && (returnAddrOnStack || !hasLocals);
226 
227   if (tryLeafOptimize) {
228     SmallSet<unsigned, 16> LeafUsed;
229     SmallSet<int, 24> LeafVisited;
230     markUsedRegsInSuccessors(MF.front(), LeafUsed, LeafVisited);
231     for (unsigned Reg : TempRegs) {
232       bool canUse = true;
233       for (MCRegAliasIterator AI(Reg, TRI, true); AI.isValid(); ++AI) {
234         if (LeafUsed.count(*AI)) {
235           canUse = false;
236           break;
237         }
238       }
239       if (canUse) {
240         MFI.setReturnProtectorRegister(Reg);
241         MFI.setReturnProtectorNeedsStore(false);
242         return true;
243       }
244     }
245   }
246 
247   // For non-leaf functions, we only need to search save / restore blocks
248   SmallSet<unsigned, 16> Used;
249   SmallSet<int, 24> Visited;
250 
251   // CSR spills happen at the beginning of this block
252   // so we can mark it as visited because anything past it is safe
253   for (auto &SB : SaveBlocks)
254     Visited.insert(SB->getNumber());
255 
256   // CSR Restores happen at the end of restore blocks, before any terminators,
257   // so we need to search restores for MBB terminators, and any successor BBs.
258   for (auto &RB : RestoreBlocks) {
259     for (auto &RBI : RB->terminators()) {
260       for (auto &RBIOp : RBI.operands()) {
261         if (RBIOp.isReg())
262           Used.insert(RBIOp.getReg());
263       }
264     }
265     for (auto &SuccMBB : RB->successors())
266       markUsedRegsInSuccessors(*SuccMBB, Used, Visited);
267   }
268 
269   // Now we iterate from the front to find code paths that
270   // bypass save blocks and land on return blocks
271   markUsedRegsInSuccessors(MF.front(), Used, Visited);
272 
273   // Now we have gathered all the regs used outside the frame save / restore,
274   // so we can see if we have a free reg to use for the retguard cookie.
275   for (unsigned Reg : TempRegs) {
276     bool canUse = true;
277     for (MCRegAliasIterator AI(Reg, TRI, true); AI.isValid(); ++AI) {
278       if (Used.count(*AI)) {
279         // Reg is used somewhere, so we cannot use it
280         canUse = false;
281         break;
282       }
283     }
284     if (canUse) {
285       MFI.setReturnProtectorRegister(Reg);
286       break;
287     }
288   }
289 
290   return MFI.hasReturnProtectorRegister();
291 }
292 
293 /// insertReturnProtectors - insert return protector instrumentation.
insertReturnProtectors(MachineFunction & MF) const294 void ReturnProtectorLowering::insertReturnProtectors(
295     MachineFunction &MF) const {
296   MachineFrameInfo &MFI = MF.getFrameInfo();
297 
298   if (!MFI.getReturnProtectorNeeded())
299     return;
300 
301   if (!MFI.hasReturnProtectorRegister())
302     llvm_unreachable("Inconsistent return protector state.");
303 
304   const Function &Fn = MF.getFunction();
305   const Module *M = Fn.getParent();
306   GlobalVariable *cookie =
307       dyn_cast_or_null<GlobalVariable>(M->getGlobalVariable(
308           Fn.getFnAttribute("ret-protector-cookie").getValueAsString(),
309           Type::getInt8PtrTy(M->getContext())));
310 
311   if (!cookie)
312     llvm_unreachable("Function needs return protector but no cookie assigned");
313 
314   unsigned Reg = MFI.getReturnProtectorRegister();
315 
316   std::vector<MachineInstr *> returns;
317   for (auto &MBB : MF) {
318     if (MBB.isReturnBlock()) {
319       for (auto &MI : MBB.terminators()) {
320         if (opcodeIsReturn(MI.getOpcode())) {
321           returns.push_back(&MI);
322           if (!MBB.isLiveIn(Reg))
323             MBB.addLiveIn(Reg);
324         }
325       }
326     }
327   }
328 
329   if (returns.empty())
330     return;
331 
332   for (auto &MI : returns)
333     insertReturnProtectorEpilogue(MF, *MI, cookie);
334 
335   insertReturnProtectorPrologue(MF, MF.front(), cookie);
336 
337   if (!MF.front().isLiveIn(Reg))
338     MF.front().addLiveIn(Reg);
339 }
340