xref: /llvm-project/llvm/lib/Target/X86/X86WinEHState.cpp (revision 14359ef1b6a0610ac91df5f5a91c88a0b51c187c)
1 //===-- X86WinEHState - Insert EH state updates for win32 exceptions ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // All functions using an MSVC EH personality use an explicitly updated state
10 // number stored in an exception registration stack object. The registration
11 // object is linked into a thread-local chain of registrations stored at fs:00.
12 // This pass adds the registration object and EH state updates.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "X86.h"
17 #include "llvm/ADT/PostOrderIterator.h"
18 #include "llvm/Analysis/CFG.h"
19 #include "llvm/Analysis/EHPersonalities.h"
20 #include "llvm/CodeGen/MachineModuleInfo.h"
21 #include "llvm/CodeGen/WinEHFuncInfo.h"
22 #include "llvm/IR/CallSite.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Instructions.h"
26 #include "llvm/IR/IntrinsicInst.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/Debug.h"
30 #include <deque>
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "winehstate"
35 
36 namespace {
37 const int OverdefinedState = INT_MIN;
38 
39 class WinEHStatePass : public FunctionPass {
40 public:
41   static char ID; // Pass identification, replacement for typeid.
42 
43   WinEHStatePass() : FunctionPass(ID) {
44     initializeWinEHStatePassPass(*PassRegistry::getPassRegistry());
45   }
46 
47   bool runOnFunction(Function &Fn) override;
48 
49   bool doInitialization(Module &M) override;
50 
51   bool doFinalization(Module &M) override;
52 
53   void getAnalysisUsage(AnalysisUsage &AU) const override;
54 
55   StringRef getPassName() const override {
56     return "Windows 32-bit x86 EH state insertion";
57   }
58 
59 private:
60   void emitExceptionRegistrationRecord(Function *F);
61 
62   void linkExceptionRegistration(IRBuilder<> &Builder, Function *Handler);
63   void unlinkExceptionRegistration(IRBuilder<> &Builder);
64   void addStateStores(Function &F, WinEHFuncInfo &FuncInfo);
65   void insertStateNumberStore(Instruction *IP, int State);
66 
67   Value *emitEHLSDA(IRBuilder<> &Builder, Function *F);
68 
69   Function *generateLSDAInEAXThunk(Function *ParentFunc);
70 
71   bool isStateStoreNeeded(EHPersonality Personality, CallSite CS);
72   void rewriteSetJmpCallSite(IRBuilder<> &Builder, Function &F, CallSite CS,
73                              Value *State);
74   int getBaseStateForBB(DenseMap<BasicBlock *, ColorVector> &BlockColors,
75                         WinEHFuncInfo &FuncInfo, BasicBlock *BB);
76   int getStateForCallSite(DenseMap<BasicBlock *, ColorVector> &BlockColors,
77                           WinEHFuncInfo &FuncInfo, CallSite CS);
78 
79   // Module-level type getters.
80   Type *getEHLinkRegistrationType();
81   Type *getSEHRegistrationType();
82   Type *getCXXEHRegistrationType();
83 
84   // Per-module data.
85   Module *TheModule = nullptr;
86   StructType *EHLinkRegistrationTy = nullptr;
87   StructType *CXXEHRegistrationTy = nullptr;
88   StructType *SEHRegistrationTy = nullptr;
89   FunctionCallee SetJmp3 = nullptr;
90   FunctionCallee CxxLongjmpUnwind = nullptr;
91 
92   // Per-function state
93   EHPersonality Personality = EHPersonality::Unknown;
94   Function *PersonalityFn = nullptr;
95   bool UseStackGuard = false;
96   int ParentBaseState;
97   FunctionCallee SehLongjmpUnwind = nullptr;
98   Constant *Cookie = nullptr;
99 
100   /// The stack allocation containing all EH data, including the link in the
101   /// fs:00 chain and the current state.
102   AllocaInst *RegNode = nullptr;
103 
104   // The allocation containing the EH security guard.
105   AllocaInst *EHGuardNode = nullptr;
106 
107   /// The index of the state field of RegNode.
108   int StateFieldIndex = ~0U;
109 
110   /// The linked list node subobject inside of RegNode.
111   Value *Link = nullptr;
112 };
113 }
114 
115 FunctionPass *llvm::createX86WinEHStatePass() { return new WinEHStatePass(); }
116 
117 char WinEHStatePass::ID = 0;
118 
119 INITIALIZE_PASS(WinEHStatePass, "x86-winehstate",
120                 "Insert stores for EH state numbers", false, false)
121 
122 bool WinEHStatePass::doInitialization(Module &M) {
123   TheModule = &M;
124   return false;
125 }
126 
127 bool WinEHStatePass::doFinalization(Module &M) {
128   assert(TheModule == &M);
129   TheModule = nullptr;
130   EHLinkRegistrationTy = nullptr;
131   CXXEHRegistrationTy = nullptr;
132   SEHRegistrationTy = nullptr;
133   SetJmp3 = nullptr;
134   CxxLongjmpUnwind = nullptr;
135   SehLongjmpUnwind = nullptr;
136   Cookie = nullptr;
137   return false;
138 }
139 
140 void WinEHStatePass::getAnalysisUsage(AnalysisUsage &AU) const {
141   // This pass should only insert a stack allocation, memory accesses, and
142   // localrecovers.
143   AU.setPreservesCFG();
144 }
145 
146 bool WinEHStatePass::runOnFunction(Function &F) {
147   // Don't insert state stores or exception handler thunks for
148   // available_externally functions. The handler needs to reference the LSDA,
149   // which will not be emitted in this case.
150   if (F.hasAvailableExternallyLinkage())
151     return false;
152 
153   // Check the personality. Do nothing if this personality doesn't use funclets.
154   if (!F.hasPersonalityFn())
155     return false;
156   PersonalityFn =
157       dyn_cast<Function>(F.getPersonalityFn()->stripPointerCasts());
158   if (!PersonalityFn)
159     return false;
160   Personality = classifyEHPersonality(PersonalityFn);
161   if (!isFuncletEHPersonality(Personality))
162     return false;
163 
164   // Skip this function if there are no EH pads and we aren't using IR-level
165   // outlining.
166   bool HasPads = false;
167   for (BasicBlock &BB : F) {
168     if (BB.isEHPad()) {
169       HasPads = true;
170       break;
171     }
172   }
173   if (!HasPads)
174     return false;
175 
176   Type *Int8PtrType = Type::getInt8PtrTy(TheModule->getContext());
177   SetJmp3 = TheModule->getOrInsertFunction(
178       "_setjmp3", FunctionType::get(
179                       Type::getInt32Ty(TheModule->getContext()),
180                       {Int8PtrType, Type::getInt32Ty(TheModule->getContext())},
181                       /*isVarArg=*/true));
182 
183   // Disable frame pointer elimination in this function.
184   // FIXME: Do the nested handlers need to keep the parent ebp in ebp, or can we
185   // use an arbitrary register?
186   F.addFnAttr("no-frame-pointer-elim", "true");
187 
188   emitExceptionRegistrationRecord(&F);
189 
190   // The state numbers calculated here in IR must agree with what we calculate
191   // later on for the MachineFunction. In particular, if an IR pass deletes an
192   // unreachable EH pad after this point before machine CFG construction, we
193   // will be in trouble. If this assumption is ever broken, we should turn the
194   // numbers into an immutable analysis pass.
195   WinEHFuncInfo FuncInfo;
196   addStateStores(F, FuncInfo);
197 
198   // Reset per-function state.
199   PersonalityFn = nullptr;
200   Personality = EHPersonality::Unknown;
201   UseStackGuard = false;
202   RegNode = nullptr;
203   EHGuardNode = nullptr;
204 
205   return true;
206 }
207 
208 /// Get the common EH registration subobject:
209 ///   typedef _EXCEPTION_DISPOSITION (*PEXCEPTION_ROUTINE)(
210 ///       _EXCEPTION_RECORD *, void *, _CONTEXT *, void *);
211 ///   struct EHRegistrationNode {
212 ///     EHRegistrationNode *Next;
213 ///     PEXCEPTION_ROUTINE Handler;
214 ///   };
215 Type *WinEHStatePass::getEHLinkRegistrationType() {
216   if (EHLinkRegistrationTy)
217     return EHLinkRegistrationTy;
218   LLVMContext &Context = TheModule->getContext();
219   EHLinkRegistrationTy = StructType::create(Context, "EHRegistrationNode");
220   Type *FieldTys[] = {
221       EHLinkRegistrationTy->getPointerTo(0), // EHRegistrationNode *Next
222       Type::getInt8PtrTy(Context) // EXCEPTION_DISPOSITION (*Handler)(...)
223   };
224   EHLinkRegistrationTy->setBody(FieldTys, false);
225   return EHLinkRegistrationTy;
226 }
227 
228 /// The __CxxFrameHandler3 registration node:
229 ///   struct CXXExceptionRegistration {
230 ///     void *SavedESP;
231 ///     EHRegistrationNode SubRecord;
232 ///     int32_t TryLevel;
233 ///   };
234 Type *WinEHStatePass::getCXXEHRegistrationType() {
235   if (CXXEHRegistrationTy)
236     return CXXEHRegistrationTy;
237   LLVMContext &Context = TheModule->getContext();
238   Type *FieldTys[] = {
239       Type::getInt8PtrTy(Context), // void *SavedESP
240       getEHLinkRegistrationType(), // EHRegistrationNode SubRecord
241       Type::getInt32Ty(Context)    // int32_t TryLevel
242   };
243   CXXEHRegistrationTy =
244       StructType::create(FieldTys, "CXXExceptionRegistration");
245   return CXXEHRegistrationTy;
246 }
247 
248 /// The _except_handler3/4 registration node:
249 ///   struct EH4ExceptionRegistration {
250 ///     void *SavedESP;
251 ///     _EXCEPTION_POINTERS *ExceptionPointers;
252 ///     EHRegistrationNode SubRecord;
253 ///     int32_t EncodedScopeTable;
254 ///     int32_t TryLevel;
255 ///   };
256 Type *WinEHStatePass::getSEHRegistrationType() {
257   if (SEHRegistrationTy)
258     return SEHRegistrationTy;
259   LLVMContext &Context = TheModule->getContext();
260   Type *FieldTys[] = {
261       Type::getInt8PtrTy(Context), // void *SavedESP
262       Type::getInt8PtrTy(Context), // void *ExceptionPointers
263       getEHLinkRegistrationType(), // EHRegistrationNode SubRecord
264       Type::getInt32Ty(Context),   // int32_t EncodedScopeTable
265       Type::getInt32Ty(Context)    // int32_t TryLevel
266   };
267   SEHRegistrationTy = StructType::create(FieldTys, "SEHExceptionRegistration");
268   return SEHRegistrationTy;
269 }
270 
271 // Emit an exception registration record. These are stack allocations with the
272 // common subobject of two pointers: the previous registration record (the old
273 // fs:00) and the personality function for the current frame. The data before
274 // and after that is personality function specific.
275 void WinEHStatePass::emitExceptionRegistrationRecord(Function *F) {
276   assert(Personality == EHPersonality::MSVC_CXX ||
277          Personality == EHPersonality::MSVC_X86SEH);
278 
279   // Struct type of RegNode. Used for GEPing.
280   Type *RegNodeTy;
281 
282   IRBuilder<> Builder(&F->getEntryBlock(), F->getEntryBlock().begin());
283   Type *Int8PtrType = Builder.getInt8PtrTy();
284   Type *Int32Ty = Builder.getInt32Ty();
285   Type *VoidTy = Builder.getVoidTy();
286 
287   if (Personality == EHPersonality::MSVC_CXX) {
288     RegNodeTy = getCXXEHRegistrationType();
289     RegNode = Builder.CreateAlloca(RegNodeTy);
290     // SavedESP = llvm.stacksave()
291     Value *SP = Builder.CreateCall(
292         Intrinsic::getDeclaration(TheModule, Intrinsic::stacksave), {});
293     Builder.CreateStore(SP, Builder.CreateStructGEP(RegNodeTy, RegNode, 0));
294     // TryLevel = -1
295     StateFieldIndex = 2;
296     ParentBaseState = -1;
297     insertStateNumberStore(&*Builder.GetInsertPoint(), ParentBaseState);
298     // Handler = __ehhandler$F
299     Function *Trampoline = generateLSDAInEAXThunk(F);
300     Link = Builder.CreateStructGEP(RegNodeTy, RegNode, 1);
301     linkExceptionRegistration(Builder, Trampoline);
302 
303     CxxLongjmpUnwind = TheModule->getOrInsertFunction(
304         "__CxxLongjmpUnwind",
305         FunctionType::get(VoidTy, Int8PtrType, /*isVarArg=*/false));
306     cast<Function>(CxxLongjmpUnwind.getCallee()->stripPointerCasts())
307         ->setCallingConv(CallingConv::X86_StdCall);
308   } else if (Personality == EHPersonality::MSVC_X86SEH) {
309     // If _except_handler4 is in use, some additional guard checks and prologue
310     // stuff is required.
311     StringRef PersonalityName = PersonalityFn->getName();
312     UseStackGuard = (PersonalityName == "_except_handler4");
313 
314     // Allocate local structures.
315     RegNodeTy = getSEHRegistrationType();
316     RegNode = Builder.CreateAlloca(RegNodeTy);
317     if (UseStackGuard)
318       EHGuardNode = Builder.CreateAlloca(Int32Ty);
319 
320     // SavedESP = llvm.stacksave()
321     Value *SP = Builder.CreateCall(
322         Intrinsic::getDeclaration(TheModule, Intrinsic::stacksave), {});
323     Builder.CreateStore(SP, Builder.CreateStructGEP(RegNodeTy, RegNode, 0));
324     // TryLevel = -2 / -1
325     StateFieldIndex = 4;
326     ParentBaseState = UseStackGuard ? -2 : -1;
327     insertStateNumberStore(&*Builder.GetInsertPoint(), ParentBaseState);
328     // ScopeTable = llvm.x86.seh.lsda(F)
329     Value *LSDA = emitEHLSDA(Builder, F);
330     LSDA = Builder.CreatePtrToInt(LSDA, Int32Ty);
331     // If using _except_handler4, xor the address of the table with
332     // __security_cookie.
333     if (UseStackGuard) {
334       Cookie = TheModule->getOrInsertGlobal("__security_cookie", Int32Ty);
335       Value *Val = Builder.CreateLoad(Int32Ty, Cookie, "cookie");
336       LSDA = Builder.CreateXor(LSDA, Val);
337     }
338     Builder.CreateStore(LSDA, Builder.CreateStructGEP(RegNodeTy, RegNode, 3));
339 
340     // If using _except_handler4, the EHGuard contains: FramePtr xor Cookie.
341     if (UseStackGuard) {
342       Value *Val = Builder.CreateLoad(Int32Ty, Cookie);
343       Value *FrameAddr = Builder.CreateCall(
344           Intrinsic::getDeclaration(TheModule, Intrinsic::frameaddress),
345           Builder.getInt32(0), "frameaddr");
346       Value *FrameAddrI32 = Builder.CreatePtrToInt(FrameAddr, Int32Ty);
347       FrameAddrI32 = Builder.CreateXor(FrameAddrI32, Val);
348       Builder.CreateStore(FrameAddrI32, EHGuardNode);
349     }
350 
351     // Register the exception handler.
352     Link = Builder.CreateStructGEP(RegNodeTy, RegNode, 2);
353     linkExceptionRegistration(Builder, PersonalityFn);
354 
355     SehLongjmpUnwind = TheModule->getOrInsertFunction(
356         UseStackGuard ? "_seh_longjmp_unwind4" : "_seh_longjmp_unwind",
357         FunctionType::get(Type::getVoidTy(TheModule->getContext()), Int8PtrType,
358                           /*isVarArg=*/false));
359     cast<Function>(SehLongjmpUnwind.getCallee()->stripPointerCasts())
360         ->setCallingConv(CallingConv::X86_StdCall);
361   } else {
362     llvm_unreachable("unexpected personality function");
363   }
364 
365   // Insert an unlink before all returns.
366   for (BasicBlock &BB : *F) {
367     Instruction *T = BB.getTerminator();
368     if (!isa<ReturnInst>(T))
369       continue;
370     Builder.SetInsertPoint(T);
371     unlinkExceptionRegistration(Builder);
372   }
373 }
374 
375 Value *WinEHStatePass::emitEHLSDA(IRBuilder<> &Builder, Function *F) {
376   Value *FI8 = Builder.CreateBitCast(F, Type::getInt8PtrTy(F->getContext()));
377   return Builder.CreateCall(
378       Intrinsic::getDeclaration(TheModule, Intrinsic::x86_seh_lsda), FI8);
379 }
380 
381 /// Generate a thunk that puts the LSDA of ParentFunc in EAX and then calls
382 /// PersonalityFn, forwarding the parameters passed to PEXCEPTION_ROUTINE:
383 ///   typedef _EXCEPTION_DISPOSITION (*PEXCEPTION_ROUTINE)(
384 ///       _EXCEPTION_RECORD *, void *, _CONTEXT *, void *);
385 /// We essentially want this code:
386 ///   movl $lsda, %eax
387 ///   jmpl ___CxxFrameHandler3
388 Function *WinEHStatePass::generateLSDAInEAXThunk(Function *ParentFunc) {
389   LLVMContext &Context = ParentFunc->getContext();
390   Type *Int32Ty = Type::getInt32Ty(Context);
391   Type *Int8PtrType = Type::getInt8PtrTy(Context);
392   Type *ArgTys[5] = {Int8PtrType, Int8PtrType, Int8PtrType, Int8PtrType,
393                      Int8PtrType};
394   FunctionType *TrampolineTy =
395       FunctionType::get(Int32Ty, makeArrayRef(&ArgTys[0], 4),
396                         /*isVarArg=*/false);
397   FunctionType *TargetFuncTy =
398       FunctionType::get(Int32Ty, makeArrayRef(&ArgTys[0], 5),
399                         /*isVarArg=*/false);
400   Function *Trampoline =
401       Function::Create(TrampolineTy, GlobalValue::InternalLinkage,
402                        Twine("__ehhandler$") + GlobalValue::dropLLVMManglingEscape(
403                                                    ParentFunc->getName()),
404                        TheModule);
405   if (auto *C = ParentFunc->getComdat())
406     Trampoline->setComdat(C);
407   BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", Trampoline);
408   IRBuilder<> Builder(EntryBB);
409   Value *LSDA = emitEHLSDA(Builder, ParentFunc);
410   Value *CastPersonality =
411       Builder.CreateBitCast(PersonalityFn, TargetFuncTy->getPointerTo());
412   auto AI = Trampoline->arg_begin();
413   Value *Args[5] = {LSDA, &*AI++, &*AI++, &*AI++, &*AI++};
414   CallInst *Call = Builder.CreateCall(TargetFuncTy, CastPersonality, Args);
415   // Can't use musttail due to prototype mismatch, but we can use tail.
416   Call->setTailCall(true);
417   // Set inreg so we pass it in EAX.
418   Call->addParamAttr(0, Attribute::InReg);
419   Builder.CreateRet(Call);
420   return Trampoline;
421 }
422 
423 void WinEHStatePass::linkExceptionRegistration(IRBuilder<> &Builder,
424                                                Function *Handler) {
425   // Emit the .safeseh directive for this function.
426   Handler->addFnAttr("safeseh");
427 
428   Type *LinkTy = getEHLinkRegistrationType();
429   // Handler = Handler
430   Value *HandlerI8 = Builder.CreateBitCast(Handler, Builder.getInt8PtrTy());
431   Builder.CreateStore(HandlerI8, Builder.CreateStructGEP(LinkTy, Link, 1));
432   // Next = [fs:00]
433   Constant *FSZero =
434       Constant::getNullValue(LinkTy->getPointerTo()->getPointerTo(257));
435   Value *Next = Builder.CreateLoad(LinkTy->getPointerTo(), FSZero);
436   Builder.CreateStore(Next, Builder.CreateStructGEP(LinkTy, Link, 0));
437   // [fs:00] = Link
438   Builder.CreateStore(Link, FSZero);
439 }
440 
441 void WinEHStatePass::unlinkExceptionRegistration(IRBuilder<> &Builder) {
442   // Clone Link into the current BB for better address mode folding.
443   if (auto *GEP = dyn_cast<GetElementPtrInst>(Link)) {
444     GEP = cast<GetElementPtrInst>(GEP->clone());
445     Builder.Insert(GEP);
446     Link = GEP;
447   }
448   Type *LinkTy = getEHLinkRegistrationType();
449   // [fs:00] = Link->Next
450   Value *Next = Builder.CreateLoad(LinkTy->getPointerTo(),
451                                    Builder.CreateStructGEP(LinkTy, Link, 0));
452   Constant *FSZero =
453       Constant::getNullValue(LinkTy->getPointerTo()->getPointerTo(257));
454   Builder.CreateStore(Next, FSZero);
455 }
456 
457 // Calls to setjmp(p) are lowered to _setjmp3(p, 0) by the frontend.
458 // The idea behind _setjmp3 is that it takes an optional number of personality
459 // specific parameters to indicate how to restore the personality-specific frame
460 // state when longjmp is initiated.  Typically, the current TryLevel is saved.
461 void WinEHStatePass::rewriteSetJmpCallSite(IRBuilder<> &Builder, Function &F,
462                                            CallSite CS, Value *State) {
463   // Don't rewrite calls with a weird number of arguments.
464   if (CS.getNumArgOperands() != 2)
465     return;
466 
467   Instruction *Inst = CS.getInstruction();
468 
469   SmallVector<OperandBundleDef, 1> OpBundles;
470   CS.getOperandBundlesAsDefs(OpBundles);
471 
472   SmallVector<Value *, 3> OptionalArgs;
473   if (Personality == EHPersonality::MSVC_CXX) {
474     OptionalArgs.push_back(CxxLongjmpUnwind.getCallee());
475     OptionalArgs.push_back(State);
476     OptionalArgs.push_back(emitEHLSDA(Builder, &F));
477   } else if (Personality == EHPersonality::MSVC_X86SEH) {
478     OptionalArgs.push_back(SehLongjmpUnwind.getCallee());
479     OptionalArgs.push_back(State);
480     if (UseStackGuard)
481       OptionalArgs.push_back(Cookie);
482   } else {
483     llvm_unreachable("unhandled personality!");
484   }
485 
486   SmallVector<Value *, 5> Args;
487   Args.push_back(
488       Builder.CreateBitCast(CS.getArgOperand(0), Builder.getInt8PtrTy()));
489   Args.push_back(Builder.getInt32(OptionalArgs.size()));
490   Args.append(OptionalArgs.begin(), OptionalArgs.end());
491 
492   CallSite NewCS;
493   if (CS.isCall()) {
494     auto *CI = cast<CallInst>(Inst);
495     CallInst *NewCI = Builder.CreateCall(SetJmp3, Args, OpBundles);
496     NewCI->setTailCallKind(CI->getTailCallKind());
497     NewCS = NewCI;
498   } else {
499     auto *II = cast<InvokeInst>(Inst);
500     NewCS = Builder.CreateInvoke(
501         SetJmp3, II->getNormalDest(), II->getUnwindDest(), Args, OpBundles);
502   }
503   NewCS.setCallingConv(CS.getCallingConv());
504   NewCS.setAttributes(CS.getAttributes());
505   NewCS->setDebugLoc(CS->getDebugLoc());
506 
507   Instruction *NewInst = NewCS.getInstruction();
508   NewInst->takeName(Inst);
509   Inst->replaceAllUsesWith(NewInst);
510   Inst->eraseFromParent();
511 }
512 
513 // Figure out what state we should assign calls in this block.
514 int WinEHStatePass::getBaseStateForBB(
515     DenseMap<BasicBlock *, ColorVector> &BlockColors, WinEHFuncInfo &FuncInfo,
516     BasicBlock *BB) {
517   int BaseState = ParentBaseState;
518   auto &BBColors = BlockColors[BB];
519 
520   assert(BBColors.size() == 1 && "multi-color BB not removed by preparation");
521   BasicBlock *FuncletEntryBB = BBColors.front();
522   if (auto *FuncletPad =
523           dyn_cast<FuncletPadInst>(FuncletEntryBB->getFirstNonPHI())) {
524     auto BaseStateI = FuncInfo.FuncletBaseStateMap.find(FuncletPad);
525     if (BaseStateI != FuncInfo.FuncletBaseStateMap.end())
526       BaseState = BaseStateI->second;
527   }
528 
529   return BaseState;
530 }
531 
532 // Calculate the state a call-site is in.
533 int WinEHStatePass::getStateForCallSite(
534     DenseMap<BasicBlock *, ColorVector> &BlockColors, WinEHFuncInfo &FuncInfo,
535     CallSite CS) {
536   if (auto *II = dyn_cast<InvokeInst>(CS.getInstruction())) {
537     // Look up the state number of the EH pad this unwinds to.
538     assert(FuncInfo.InvokeStateMap.count(II) && "invoke has no state!");
539     return FuncInfo.InvokeStateMap[II];
540   }
541   // Possibly throwing call instructions have no actions to take after
542   // an unwind. Ensure they are in the -1 state.
543   return getBaseStateForBB(BlockColors, FuncInfo, CS.getParent());
544 }
545 
546 // Calculate the intersection of all the FinalStates for a BasicBlock's
547 // predecessors.
548 static int getPredState(DenseMap<BasicBlock *, int> &FinalStates, Function &F,
549                         int ParentBaseState, BasicBlock *BB) {
550   // The entry block has no predecessors but we know that the prologue always
551   // sets us up with a fixed state.
552   if (&F.getEntryBlock() == BB)
553     return ParentBaseState;
554 
555   // This is an EH Pad, conservatively report this basic block as overdefined.
556   if (BB->isEHPad())
557     return OverdefinedState;
558 
559   int CommonState = OverdefinedState;
560   for (BasicBlock *PredBB : predecessors(BB)) {
561     // We didn't manage to get a state for one of these predecessors,
562     // conservatively report this basic block as overdefined.
563     auto PredEndState = FinalStates.find(PredBB);
564     if (PredEndState == FinalStates.end())
565       return OverdefinedState;
566 
567     // This code is reachable via exceptional control flow,
568     // conservatively report this basic block as overdefined.
569     if (isa<CatchReturnInst>(PredBB->getTerminator()))
570       return OverdefinedState;
571 
572     int PredState = PredEndState->second;
573     assert(PredState != OverdefinedState &&
574            "overdefined BBs shouldn't be in FinalStates");
575     if (CommonState == OverdefinedState)
576       CommonState = PredState;
577 
578     // At least two predecessors have different FinalStates,
579     // conservatively report this basic block as overdefined.
580     if (CommonState != PredState)
581       return OverdefinedState;
582   }
583 
584   return CommonState;
585 }
586 
587 // Calculate the intersection of all the InitialStates for a BasicBlock's
588 // successors.
589 static int getSuccState(DenseMap<BasicBlock *, int> &InitialStates, Function &F,
590                         int ParentBaseState, BasicBlock *BB) {
591   // This block rejoins normal control flow,
592   // conservatively report this basic block as overdefined.
593   if (isa<CatchReturnInst>(BB->getTerminator()))
594     return OverdefinedState;
595 
596   int CommonState = OverdefinedState;
597   for (BasicBlock *SuccBB : successors(BB)) {
598     // We didn't manage to get a state for one of these predecessors,
599     // conservatively report this basic block as overdefined.
600     auto SuccStartState = InitialStates.find(SuccBB);
601     if (SuccStartState == InitialStates.end())
602       return OverdefinedState;
603 
604     // This is an EH Pad, conservatively report this basic block as overdefined.
605     if (SuccBB->isEHPad())
606       return OverdefinedState;
607 
608     int SuccState = SuccStartState->second;
609     assert(SuccState != OverdefinedState &&
610            "overdefined BBs shouldn't be in FinalStates");
611     if (CommonState == OverdefinedState)
612       CommonState = SuccState;
613 
614     // At least two successors have different InitialStates,
615     // conservatively report this basic block as overdefined.
616     if (CommonState != SuccState)
617       return OverdefinedState;
618   }
619 
620   return CommonState;
621 }
622 
623 bool WinEHStatePass::isStateStoreNeeded(EHPersonality Personality,
624                                         CallSite CS) {
625   if (!CS)
626     return false;
627 
628   // If the function touches memory, it needs a state store.
629   if (isAsynchronousEHPersonality(Personality))
630     return !CS.doesNotAccessMemory();
631 
632   // If the function throws, it needs a state store.
633   return !CS.doesNotThrow();
634 }
635 
636 void WinEHStatePass::addStateStores(Function &F, WinEHFuncInfo &FuncInfo) {
637   // Mark the registration node. The backend needs to know which alloca it is so
638   // that it can recover the original frame pointer.
639   IRBuilder<> Builder(RegNode->getNextNode());
640   Value *RegNodeI8 = Builder.CreateBitCast(RegNode, Builder.getInt8PtrTy());
641   Builder.CreateCall(
642       Intrinsic::getDeclaration(TheModule, Intrinsic::x86_seh_ehregnode),
643       {RegNodeI8});
644 
645   if (EHGuardNode) {
646     IRBuilder<> Builder(EHGuardNode->getNextNode());
647     Value *EHGuardNodeI8 =
648         Builder.CreateBitCast(EHGuardNode, Builder.getInt8PtrTy());
649     Builder.CreateCall(
650         Intrinsic::getDeclaration(TheModule, Intrinsic::x86_seh_ehguard),
651         {EHGuardNodeI8});
652   }
653 
654   // Calculate state numbers.
655   if (isAsynchronousEHPersonality(Personality))
656     calculateSEHStateNumbers(&F, FuncInfo);
657   else
658     calculateWinCXXEHStateNumbers(&F, FuncInfo);
659 
660   // Iterate all the instructions and emit state number stores.
661   DenseMap<BasicBlock *, ColorVector> BlockColors = colorEHFunclets(F);
662   ReversePostOrderTraversal<Function *> RPOT(&F);
663 
664   // InitialStates yields the state of the first call-site for a BasicBlock.
665   DenseMap<BasicBlock *, int> InitialStates;
666   // FinalStates yields the state of the last call-site for a BasicBlock.
667   DenseMap<BasicBlock *, int> FinalStates;
668   // Worklist used to revisit BasicBlocks with indeterminate
669   // Initial/Final-States.
670   std::deque<BasicBlock *> Worklist;
671   // Fill in InitialStates and FinalStates for BasicBlocks with call-sites.
672   for (BasicBlock *BB : RPOT) {
673     int InitialState = OverdefinedState;
674     int FinalState;
675     if (&F.getEntryBlock() == BB)
676       InitialState = FinalState = ParentBaseState;
677     for (Instruction &I : *BB) {
678       CallSite CS(&I);
679       if (!isStateStoreNeeded(Personality, CS))
680         continue;
681 
682       int State = getStateForCallSite(BlockColors, FuncInfo, CS);
683       if (InitialState == OverdefinedState)
684         InitialState = State;
685       FinalState = State;
686     }
687     // No call-sites in this basic block? That's OK, we will come back to these
688     // in a later pass.
689     if (InitialState == OverdefinedState) {
690       Worklist.push_back(BB);
691       continue;
692     }
693     LLVM_DEBUG(dbgs() << "X86WinEHState: " << BB->getName()
694                       << " InitialState=" << InitialState << '\n');
695     LLVM_DEBUG(dbgs() << "X86WinEHState: " << BB->getName()
696                       << " FinalState=" << FinalState << '\n');
697     InitialStates.insert({BB, InitialState});
698     FinalStates.insert({BB, FinalState});
699   }
700 
701   // Try to fill-in InitialStates and FinalStates which have no call-sites.
702   while (!Worklist.empty()) {
703     BasicBlock *BB = Worklist.front();
704     Worklist.pop_front();
705     // This BasicBlock has already been figured out, nothing more we can do.
706     if (InitialStates.count(BB) != 0)
707       continue;
708 
709     int PredState = getPredState(FinalStates, F, ParentBaseState, BB);
710     if (PredState == OverdefinedState)
711       continue;
712 
713     // We successfully inferred this BasicBlock's state via it's predecessors;
714     // enqueue it's successors to see if we can infer their states.
715     InitialStates.insert({BB, PredState});
716     FinalStates.insert({BB, PredState});
717     for (BasicBlock *SuccBB : successors(BB))
718       Worklist.push_back(SuccBB);
719   }
720 
721   // Try to hoist stores from successors.
722   for (BasicBlock *BB : RPOT) {
723     int SuccState = getSuccState(InitialStates, F, ParentBaseState, BB);
724     if (SuccState == OverdefinedState)
725       continue;
726 
727     // Update our FinalState to reflect the common InitialState of our
728     // successors.
729     FinalStates.insert({BB, SuccState});
730   }
731 
732   // Finally, insert state stores before call-sites which transition us to a new
733   // state.
734   for (BasicBlock *BB : RPOT) {
735     auto &BBColors = BlockColors[BB];
736     BasicBlock *FuncletEntryBB = BBColors.front();
737     if (isa<CleanupPadInst>(FuncletEntryBB->getFirstNonPHI()))
738       continue;
739 
740     int PrevState = getPredState(FinalStates, F, ParentBaseState, BB);
741     LLVM_DEBUG(dbgs() << "X86WinEHState: " << BB->getName()
742                       << " PrevState=" << PrevState << '\n');
743 
744     for (Instruction &I : *BB) {
745       CallSite CS(&I);
746       if (!isStateStoreNeeded(Personality, CS))
747         continue;
748 
749       int State = getStateForCallSite(BlockColors, FuncInfo, CS);
750       if (State != PrevState)
751         insertStateNumberStore(&I, State);
752       PrevState = State;
753     }
754 
755     // We might have hoisted a state store into this block, emit it now.
756     auto EndState = FinalStates.find(BB);
757     if (EndState != FinalStates.end())
758       if (EndState->second != PrevState)
759         insertStateNumberStore(BB->getTerminator(), EndState->second);
760   }
761 
762   SmallVector<CallSite, 1> SetJmp3CallSites;
763   for (BasicBlock *BB : RPOT) {
764     for (Instruction &I : *BB) {
765       CallSite CS(&I);
766       if (!CS)
767         continue;
768       if (CS.getCalledValue()->stripPointerCasts() !=
769           SetJmp3.getCallee()->stripPointerCasts())
770         continue;
771 
772       SetJmp3CallSites.push_back(CS);
773     }
774   }
775 
776   for (CallSite CS : SetJmp3CallSites) {
777     auto &BBColors = BlockColors[CS->getParent()];
778     BasicBlock *FuncletEntryBB = BBColors.front();
779     bool InCleanup = isa<CleanupPadInst>(FuncletEntryBB->getFirstNonPHI());
780 
781     IRBuilder<> Builder(CS.getInstruction());
782     Value *State;
783     if (InCleanup) {
784       Value *StateField =
785           Builder.CreateStructGEP(nullptr, RegNode, StateFieldIndex);
786       State = Builder.CreateLoad(Builder.getInt32Ty(), StateField);
787     } else {
788       State = Builder.getInt32(getStateForCallSite(BlockColors, FuncInfo, CS));
789     }
790     rewriteSetJmpCallSite(Builder, F, CS, State);
791   }
792 }
793 
794 void WinEHStatePass::insertStateNumberStore(Instruction *IP, int State) {
795   IRBuilder<> Builder(IP);
796   Value *StateField =
797       Builder.CreateStructGEP(nullptr, RegNode, StateFieldIndex);
798   Builder.CreateStore(Builder.getInt32(State), StateField);
799 }
800