xref: /llvm-project/llvm/lib/CodeGen/StackProtector.cpp (revision d31fc54f3435d5e9b0998cc8590f3cd43f91ad38)
1 //===-- StackProtector.cpp - Stack Protector Insertion --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass inserts stack protectors into functions that need them. A guard
11 // value loaded from __stack_chk_guard is stored onto the stack before the
12 // locals are allocated. Before each return, the stored copy is checked. If it's
13 // changed, then there was some sort of violation and the program aborts.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #define DEBUG_TYPE "stack-protector"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/Constants.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Function.h"
22 #include "llvm/Instructions.h"
23 #include "llvm/Module.h"
24 #include "llvm/Pass.h"
25 #include "llvm/ADT/APInt.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Target/TargetData.h"
28 #include "llvm/Target/TargetLowering.h"
29 using namespace llvm;
30 
31 // Enable stack protectors.
32 static cl::opt<unsigned>
33 SSPBufferSize("stack-protector-buffer-size", cl::init(8),
34               cl::desc("The lower bound for a buffer to be considered for "
35                        "stack smashing protection."));
36 
37 namespace {
38   class VISIBILITY_HIDDEN StackProtector : public FunctionPass {
39     /// Level - The level of stack protection.
40     SSP::StackProtectorLevel Level;
41 
42     /// TLI - Keep a pointer of a TargetLowering to consult for determining
43     /// target type sizes.
44     const TargetLowering *TLI;
45 
46     /// FailBB - Holds the basic block to jump to when the stack protector check
47     /// fails.
48     BasicBlock *FailBB;
49 
50     /// StackProtFrameSlot - The place on the stack that the stack protector
51     /// guard is kept.
52     AllocaInst *StackProtFrameSlot;
53 
54     /// StackGuardVar - The global variable for the stack guard.
55     Constant *StackGuardVar;
56 
57     Function *F;
58     Module *M;
59 
60     /// InsertStackProtectorPrologue - Insert code into the entry block that
61     /// stores the __stack_chk_guard variable onto the stack.
62     void InsertStackProtectorPrologue();
63 
64     /// InsertStackProtectorEpilogue - Insert code before the return
65     /// instructions checking the stack value that was stored in the
66     /// prologue. If it isn't the same as the original value, then call a
67     /// "failure" function.
68     void InsertStackProtectorEpilogue();
69 
70     /// CreateFailBB - Create a basic block to jump to when the stack protector
71     /// check fails.
72     void CreateFailBB();
73 
74     /// RequiresStackProtector - Check whether or not this function needs a
75     /// stack protector based upon the stack protector level.
76     bool RequiresStackProtector() const;
77   public:
78     static char ID;             // Pass identification, replacement for typeid.
79     StackProtector() : FunctionPass(&ID), Level(SSP::OFF), TLI(0), FailBB(0) {}
80     StackProtector(SSP::StackProtectorLevel lvl, const TargetLowering *tli)
81       : FunctionPass(&ID), Level(lvl), TLI(tli), FailBB(0) {}
82 
83     virtual bool runOnFunction(Function &Fn);
84   };
85 } // end anonymous namespace
86 
87 char StackProtector::ID = 0;
88 static RegisterPass<StackProtector>
89 X("stack-protector", "Insert stack protectors");
90 
91 FunctionPass *llvm::createStackProtectorPass(SSP::StackProtectorLevel lvl,
92                                              const TargetLowering *tli) {
93   return new StackProtector(lvl, tli);
94 }
95 
96 bool StackProtector::runOnFunction(Function &Fn) {
97   F = &Fn;
98   M = F->getParent();
99 
100   if (!RequiresStackProtector()) return false;
101 
102   InsertStackProtectorPrologue();
103   InsertStackProtectorEpilogue();
104 
105   // Cleanup.
106   FailBB = 0;
107   StackProtFrameSlot = 0;
108   StackGuardVar = 0;
109   return true;
110 }
111 
112 /// InsertStackProtectorPrologue - Insert code into the entry block that stores
113 /// the __stack_chk_guard variable onto the stack.
114 void StackProtector::InsertStackProtectorPrologue() {
115   BasicBlock &Entry = F->getEntryBlock();
116   Instruction &InsertPt = Entry.front();
117   const PointerType *GuardTy = PointerType::getUnqual(Type::Int8Ty);
118 
119   StackGuardVar = M->getOrInsertGlobal("__stack_chk_guard", GuardTy);
120   StackProtFrameSlot = new AllocaInst(GuardTy, "StackProt_Frame", &InsertPt);
121   LoadInst *LI = new LoadInst(StackGuardVar, "StackGuard", false, &InsertPt);
122   new StoreInst(LI, StackProtFrameSlot, false, &InsertPt);
123 }
124 
125 /// InsertStackProtectorEpilogue - Insert code before the return instructions
126 /// checking the stack value that was stored in the prologue. If it isn't the
127 /// same as the original value, then call a "failure" function.
128 void StackProtector::InsertStackProtectorEpilogue() {
129   // Create the basic block to jump to when the guard check fails.
130   CreateFailBB();
131 
132   Function::iterator I = F->begin(), E = F->end();
133   std::vector<BasicBlock*> ReturnBBs;
134   ReturnBBs.reserve(F->size());
135 
136   for (; I != E; ++I)
137     if (isa<ReturnInst>(I->getTerminator()))
138       ReturnBBs.push_back(I);
139 
140   if (ReturnBBs.empty()) return; // Odd, but could happen. . .
141 
142   // Loop through the basic blocks that have return instructions. Convert this:
143   //
144   //   return:
145   //     ...
146   //     ret ...
147   //
148   // into this:
149   //
150   //   return:
151   //     ...
152   //     %1 = load __stack_chk_guard
153   //     %2 = load <stored stack guard>
154   //     %3 = cmp i1 %1, %2
155   //     br i1 %3, label %SPRet, label %CallStackCheckFailBlk
156   //
157   //   SP_return:
158   //     ret ...
159   //
160   //   CallStackCheckFailBlk:
161   //     call void @__stack_chk_fail()
162   //     unreachable
163   //
164   for (std::vector<BasicBlock*>::iterator
165          II = ReturnBBs.begin(), IE = ReturnBBs.end(); II != IE; ++II) {
166     BasicBlock *BB = *II;
167     ReturnInst *RI = cast<ReturnInst>(BB->getTerminator());
168     Function::iterator InsPt = BB; ++InsPt; // Insertion point for new BB.
169 
170     // Split the basic block before the return instruction.
171     BasicBlock *NewBB = BB->splitBasicBlock(RI, "SP_return");
172 
173     // Move the newly created basic block to the point right after the old basic
174     // block.
175     NewBB->removeFromParent();
176     F->getBasicBlockList().insert(InsPt, NewBB);
177 
178     // Generate the stack protector instructions in the old basic block.
179     LoadInst *LI2 = new LoadInst(StackGuardVar, "", false, BB);
180     LoadInst *LI1 = new LoadInst(StackProtFrameSlot, "", true, BB);
181     ICmpInst *Cmp = new ICmpInst(CmpInst::ICMP_EQ, LI1, LI2, "", BB);
182     BranchInst::Create(NewBB, FailBB, Cmp, BB);
183   }
184 }
185 
186 /// CreateFailBB - Create a basic block to jump to when the stack protector
187 /// check fails.
188 void StackProtector::CreateFailBB() {
189   assert(!FailBB && "Failure basic block already created?!");
190   FailBB = BasicBlock::Create("CallStackCheckFailBlk", F);
191   std::vector<const Type*> Params;
192   Constant *StackChkFail =
193     M->getOrInsertFunction("__stack_chk_fail", Type::VoidTy, NULL);
194   CallInst::Create(StackChkFail, "", FailBB);
195   new UnreachableInst(FailBB);
196 }
197 
198 /// RequiresStackProtector - Check whether or not this function needs a stack
199 /// protector based upon the stack protector level.
200 bool StackProtector::RequiresStackProtector() const {
201   switch (Level) {
202   default: return false;
203   case SSP::ALL: return true;
204   case SSP::SOME: {
205     // If the size of the local variables allocated on the stack is greater than
206     // SSPBufferSize, then we require a stack protector.
207     uint64_t StackSize = 0;
208     const TargetData *TD = TLI->getTargetData();
209 
210     for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
211       BasicBlock *BB = I;
212 
213       for (BasicBlock::iterator
214              II = BB->begin(), IE = BB->end(); II != IE; ++II)
215         if (AllocaInst *AI = dyn_cast<AllocaInst>(II)) {
216           if (ConstantInt *CI = dyn_cast<ConstantInt>(AI->getArraySize())) {
217             uint64_t Bytes = TD->getTypeSizeInBits(AI->getAllocatedType()) / 8;
218             const APInt &Size = CI->getValue();
219             StackSize += Bytes * Size.getZExtValue();
220 
221             if (SSPBufferSize <= StackSize)
222               return true;
223           }
224         }
225     }
226 
227     return false;
228   }
229   }
230 }
231