xref: /llvm-project/llvm/lib/SandboxIR/Context.cpp (revision 79cbad188afd5268235b00267d37ce39544dbd3c)
1 //===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/SandboxIR/Context.h"
10 #include "llvm/SandboxIR/Function.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Module.h"
13 
14 namespace llvm::sandboxir {
15 
16 std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
17   std::unique_ptr<Value> Erased;
18   auto It = LLVMValueToValueMap.find(V);
19   if (It != LLVMValueToValueMap.end()) {
20     auto *Val = It->second.release();
21     Erased = std::unique_ptr<Value>(Val);
22     LLVMValueToValueMap.erase(It);
23   }
24   return Erased;
25 }
26 
27 std::unique_ptr<Value> Context::detach(Value *V) {
28   assert(V->getSubclassID() != Value::ClassID::Constant &&
29          "Can't detach a constant!");
30   assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
31   return detachLLVMValue(V->Val);
32 }
33 
34 Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
35   assert(VPtr->getSubclassID() != Value::ClassID::User &&
36          "Can't register a user!");
37 
38   Value *V = VPtr.get();
39   [[maybe_unused]] auto Pair =
40       LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
41   assert(Pair.second && "Already exists!");
42 
43   // Track creation of instructions.
44   // Please note that we don't allow the creation of detached instructions,
45   // meaning that the instructions need to be inserted into a block upon
46   // creation. This is why the tracker class combines creation and insertion.
47   if (auto *I = dyn_cast<Instruction>(V)) {
48     getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
49     runCreateInstrCallbacks(I);
50   }
51 
52   return V;
53 }
54 
55 Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
56   auto Pair = LLVMValueToValueMap.insert({LLVMV, nullptr});
57   auto It = Pair.first;
58   if (!Pair.second)
59     return It->second.get();
60 
61   if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
62     switch (C->getValueID()) {
63     case llvm::Value::ConstantIntVal:
64       It->second = std::unique_ptr<ConstantInt>(
65           new ConstantInt(cast<llvm::ConstantInt>(C), *this));
66       return It->second.get();
67     case llvm::Value::ConstantFPVal:
68       It->second = std::unique_ptr<ConstantFP>(
69           new ConstantFP(cast<llvm::ConstantFP>(C), *this));
70       return It->second.get();
71     case llvm::Value::BlockAddressVal:
72       It->second = std::unique_ptr<BlockAddress>(
73           new BlockAddress(cast<llvm::BlockAddress>(C), *this));
74       return It->second.get();
75     case llvm::Value::ConstantTokenNoneVal:
76       It->second = std::unique_ptr<ConstantTokenNone>(
77           new ConstantTokenNone(cast<llvm::ConstantTokenNone>(C), *this));
78       return It->second.get();
79     case llvm::Value::ConstantAggregateZeroVal: {
80       auto *CAZ = cast<llvm::ConstantAggregateZero>(C);
81       It->second = std::unique_ptr<ConstantAggregateZero>(
82           new ConstantAggregateZero(CAZ, *this));
83       auto *Ret = It->second.get();
84       // Must create sandboxir for elements.
85       auto EC = CAZ->getElementCount();
86       if (EC.isFixed()) {
87         for (auto ElmIdx : seq<unsigned>(0, EC.getFixedValue()))
88           getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ);
89       }
90       return Ret;
91     }
92     case llvm::Value::ConstantPointerNullVal:
93       It->second = std::unique_ptr<ConstantPointerNull>(
94           new ConstantPointerNull(cast<llvm::ConstantPointerNull>(C), *this));
95       return It->second.get();
96     case llvm::Value::PoisonValueVal:
97       It->second = std::unique_ptr<PoisonValue>(
98           new PoisonValue(cast<llvm::PoisonValue>(C), *this));
99       return It->second.get();
100     case llvm::Value::UndefValueVal:
101       It->second = std::unique_ptr<UndefValue>(
102           new UndefValue(cast<llvm::UndefValue>(C), *this));
103       return It->second.get();
104     case llvm::Value::DSOLocalEquivalentVal: {
105       auto *DSOLE = cast<llvm::DSOLocalEquivalent>(C);
106       It->second = std::unique_ptr<DSOLocalEquivalent>(
107           new DSOLocalEquivalent(DSOLE, *this));
108       auto *Ret = It->second.get();
109       getOrCreateValueInternal(DSOLE->getGlobalValue(), DSOLE);
110       return Ret;
111     }
112     case llvm::Value::ConstantArrayVal:
113       It->second = std::unique_ptr<ConstantArray>(
114           new ConstantArray(cast<llvm::ConstantArray>(C), *this));
115       break;
116     case llvm::Value::ConstantStructVal:
117       It->second = std::unique_ptr<ConstantStruct>(
118           new ConstantStruct(cast<llvm::ConstantStruct>(C), *this));
119       break;
120     case llvm::Value::ConstantVectorVal:
121       It->second = std::unique_ptr<ConstantVector>(
122           new ConstantVector(cast<llvm::ConstantVector>(C), *this));
123       break;
124     case llvm::Value::FunctionVal:
125       It->second = std::unique_ptr<Function>(
126           new Function(cast<llvm::Function>(C), *this));
127       break;
128     case llvm::Value::GlobalIFuncVal:
129       It->second = std::unique_ptr<GlobalIFunc>(
130           new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this));
131       break;
132     case llvm::Value::GlobalVariableVal:
133       It->second = std::unique_ptr<GlobalVariable>(
134           new GlobalVariable(cast<llvm::GlobalVariable>(C), *this));
135       break;
136     case llvm::Value::GlobalAliasVal:
137       It->second = std::unique_ptr<GlobalAlias>(
138           new GlobalAlias(cast<llvm::GlobalAlias>(C), *this));
139       break;
140     case llvm::Value::NoCFIValueVal:
141       It->second = std::unique_ptr<NoCFIValue>(
142           new NoCFIValue(cast<llvm::NoCFIValue>(C), *this));
143       break;
144     case llvm::Value::ConstantPtrAuthVal:
145       It->second = std::unique_ptr<ConstantPtrAuth>(
146           new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(C), *this));
147       break;
148     case llvm::Value::ConstantExprVal:
149       It->second = std::unique_ptr<ConstantExpr>(
150           new ConstantExpr(cast<llvm::ConstantExpr>(C), *this));
151       break;
152     default:
153       It->second = std::unique_ptr<Constant>(new Constant(C, *this));
154       break;
155     }
156     auto *NewC = It->second.get();
157     for (llvm::Value *COp : C->operands())
158       getOrCreateValueInternal(COp, C);
159     return NewC;
160   }
161   if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
162     It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
163     return It->second.get();
164   }
165   if (auto *BB = dyn_cast<llvm::BasicBlock>(LLVMV)) {
166     assert(isa<llvm::BlockAddress>(U) &&
167            "This won't create a SBBB, don't call this function directly!");
168     if (auto *SBBB = getValue(BB))
169       return SBBB;
170     return nullptr;
171   }
172   assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
173 
174   switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
175   case llvm::Instruction::VAArg: {
176     auto *LLVMVAArg = cast<llvm::VAArgInst>(LLVMV);
177     It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
178     return It->second.get();
179   }
180   case llvm::Instruction::Freeze: {
181     auto *LLVMFreeze = cast<llvm::FreezeInst>(LLVMV);
182     It->second = std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
183     return It->second.get();
184   }
185   case llvm::Instruction::Fence: {
186     auto *LLVMFence = cast<llvm::FenceInst>(LLVMV);
187     It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
188     return It->second.get();
189   }
190   case llvm::Instruction::Select: {
191     auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
192     It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
193     return It->second.get();
194   }
195   case llvm::Instruction::ExtractElement: {
196     auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
197     It->second = std::unique_ptr<ExtractElementInst>(
198         new ExtractElementInst(LLVMIns, *this));
199     return It->second.get();
200   }
201   case llvm::Instruction::InsertElement: {
202     auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
203     It->second = std::unique_ptr<InsertElementInst>(
204         new InsertElementInst(LLVMIns, *this));
205     return It->second.get();
206   }
207   case llvm::Instruction::ShuffleVector: {
208     auto *LLVMIns = cast<llvm::ShuffleVectorInst>(LLVMV);
209     It->second = std::unique_ptr<ShuffleVectorInst>(
210         new ShuffleVectorInst(LLVMIns, *this));
211     return It->second.get();
212   }
213   case llvm::Instruction::ExtractValue: {
214     auto *LLVMIns = cast<llvm::ExtractValueInst>(LLVMV);
215     It->second =
216         std::unique_ptr<ExtractValueInst>(new ExtractValueInst(LLVMIns, *this));
217     return It->second.get();
218   }
219   case llvm::Instruction::InsertValue: {
220     auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV);
221     It->second =
222         std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
223     return It->second.get();
224   }
225   case llvm::Instruction::Br: {
226     auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
227     It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
228     return It->second.get();
229   }
230   case llvm::Instruction::Load: {
231     auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
232     It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
233     return It->second.get();
234   }
235   case llvm::Instruction::Store: {
236     auto *LLVMSt = cast<llvm::StoreInst>(LLVMV);
237     It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
238     return It->second.get();
239   }
240   case llvm::Instruction::Ret: {
241     auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
242     It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
243     return It->second.get();
244   }
245   case llvm::Instruction::Call: {
246     auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
247     It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
248     return It->second.get();
249   }
250   case llvm::Instruction::Invoke: {
251     auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV);
252     It->second = std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
253     return It->second.get();
254   }
255   case llvm::Instruction::CallBr: {
256     auto *LLVMCallBr = cast<llvm::CallBrInst>(LLVMV);
257     It->second = std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
258     return It->second.get();
259   }
260   case llvm::Instruction::LandingPad: {
261     auto *LLVMLPad = cast<llvm::LandingPadInst>(LLVMV);
262     It->second =
263         std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
264     return It->second.get();
265   }
266   case llvm::Instruction::CatchPad: {
267     auto *LLVMCPI = cast<llvm::CatchPadInst>(LLVMV);
268     It->second =
269         std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
270     return It->second.get();
271   }
272   case llvm::Instruction::CleanupPad: {
273     auto *LLVMCPI = cast<llvm::CleanupPadInst>(LLVMV);
274     It->second =
275         std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
276     return It->second.get();
277   }
278   case llvm::Instruction::CatchRet: {
279     auto *LLVMCRI = cast<llvm::CatchReturnInst>(LLVMV);
280     It->second =
281         std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
282     return It->second.get();
283   }
284   case llvm::Instruction::CleanupRet: {
285     auto *LLVMCRI = cast<llvm::CleanupReturnInst>(LLVMV);
286     It->second = std::unique_ptr<CleanupReturnInst>(
287         new CleanupReturnInst(LLVMCRI, *this));
288     return It->second.get();
289   }
290   case llvm::Instruction::GetElementPtr: {
291     auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
292     It->second = std::unique_ptr<GetElementPtrInst>(
293         new GetElementPtrInst(LLVMGEP, *this));
294     return It->second.get();
295   }
296   case llvm::Instruction::CatchSwitch: {
297     auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(LLVMV);
298     It->second = std::unique_ptr<CatchSwitchInst>(
299         new CatchSwitchInst(LLVMCatchSwitchInst, *this));
300     return It->second.get();
301   }
302   case llvm::Instruction::Resume: {
303     auto *LLVMResumeInst = cast<llvm::ResumeInst>(LLVMV);
304     It->second =
305         std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
306     return It->second.get();
307   }
308   case llvm::Instruction::Switch: {
309     auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
310     It->second =
311         std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
312     return It->second.get();
313   }
314   case llvm::Instruction::FNeg: {
315     auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
316     It->second = std::unique_ptr<UnaryOperator>(
317         new UnaryOperator(LLVMUnaryOperator, *this));
318     return It->second.get();
319   }
320   case llvm::Instruction::Add:
321   case llvm::Instruction::FAdd:
322   case llvm::Instruction::Sub:
323   case llvm::Instruction::FSub:
324   case llvm::Instruction::Mul:
325   case llvm::Instruction::FMul:
326   case llvm::Instruction::UDiv:
327   case llvm::Instruction::SDiv:
328   case llvm::Instruction::FDiv:
329   case llvm::Instruction::URem:
330   case llvm::Instruction::SRem:
331   case llvm::Instruction::FRem:
332   case llvm::Instruction::Shl:
333   case llvm::Instruction::LShr:
334   case llvm::Instruction::AShr:
335   case llvm::Instruction::And:
336   case llvm::Instruction::Or:
337   case llvm::Instruction::Xor: {
338     auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV);
339     It->second = std::unique_ptr<BinaryOperator>(
340         new BinaryOperator(LLVMBinaryOperator, *this));
341     return It->second.get();
342   }
343   case llvm::Instruction::AtomicRMW: {
344     auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV);
345     It->second =
346         std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(LLVMAtomicRMW, *this));
347     return It->second.get();
348   }
349   case llvm::Instruction::AtomicCmpXchg: {
350     auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
351     It->second = std::unique_ptr<AtomicCmpXchgInst>(
352         new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
353     return It->second.get();
354   }
355   case llvm::Instruction::Alloca: {
356     auto *LLVMAlloca = cast<llvm::AllocaInst>(LLVMV);
357     It->second = std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
358     return It->second.get();
359   }
360   case llvm::Instruction::ZExt:
361   case llvm::Instruction::SExt:
362   case llvm::Instruction::FPToUI:
363   case llvm::Instruction::FPToSI:
364   case llvm::Instruction::FPExt:
365   case llvm::Instruction::PtrToInt:
366   case llvm::Instruction::IntToPtr:
367   case llvm::Instruction::SIToFP:
368   case llvm::Instruction::UIToFP:
369   case llvm::Instruction::Trunc:
370   case llvm::Instruction::FPTrunc:
371   case llvm::Instruction::BitCast:
372   case llvm::Instruction::AddrSpaceCast: {
373     auto *LLVMCast = cast<llvm::CastInst>(LLVMV);
374     It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
375     return It->second.get();
376   }
377   case llvm::Instruction::PHI: {
378     auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
379     It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
380     return It->second.get();
381   }
382   case llvm::Instruction::ICmp: {
383     auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
384     It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
385     return It->second.get();
386   }
387   case llvm::Instruction::FCmp: {
388     auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
389     It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
390     return It->second.get();
391   }
392   case llvm::Instruction::Unreachable: {
393     auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
394     It->second = std::unique_ptr<UnreachableInst>(
395         new UnreachableInst(LLVMUnreachable, *this));
396     return It->second.get();
397   }
398   default:
399     break;
400   }
401 
402   It->second = std::unique_ptr<OpaqueInst>(
403       new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
404   return It->second.get();
405 }
406 
407 Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
408   auto Pair = LLVMValueToValueMap.insert({LLVMArg, nullptr});
409   auto It = Pair.first;
410   if (Pair.second) {
411     It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
412     return cast<Argument>(It->second.get());
413   }
414   return cast<Argument>(It->second.get());
415 }
416 
417 Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
418   return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
419 }
420 
421 BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
422   assert(getValue(LLVMBB) == nullptr && "Already exists!");
423   auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
424   auto *BB = cast<BasicBlock>(registerValue(std::move(NewBBPtr)));
425   // Create SandboxIR for BB's body.
426   BB->buildBasicBlockFromLLVMIR(LLVMBB);
427   return BB;
428 }
429 
430 VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
431   auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
432   return cast<VAArgInst>(registerValue(std::move(NewPtr)));
433 }
434 
435 FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
436   auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
437   return cast<FreezeInst>(registerValue(std::move(NewPtr)));
438 }
439 
440 FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
441   auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
442   return cast<FenceInst>(registerValue(std::move(NewPtr)));
443 }
444 
445 SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
446   auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
447   return cast<SelectInst>(registerValue(std::move(NewPtr)));
448 }
449 
450 ExtractElementInst *
451 Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
452   auto NewPtr =
453       std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
454   return cast<ExtractElementInst>(registerValue(std::move(NewPtr)));
455 }
456 
457 InsertElementInst *
458 Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
459   auto NewPtr =
460       std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
461   return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
462 }
463 
464 ShuffleVectorInst *
465 Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
466   auto NewPtr =
467       std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
468   return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
469 }
470 
471 ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
472   auto NewPtr =
473       std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
474   return cast<ExtractValueInst>(registerValue(std::move(NewPtr)));
475 }
476 
477 InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
478   auto NewPtr =
479       std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
480   return cast<InsertValueInst>(registerValue(std::move(NewPtr)));
481 }
482 
483 BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
484   auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
485   return cast<BranchInst>(registerValue(std::move(NewPtr)));
486 }
487 
488 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
489   auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
490   return cast<LoadInst>(registerValue(std::move(NewPtr)));
491 }
492 
493 StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
494   auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
495   return cast<StoreInst>(registerValue(std::move(NewPtr)));
496 }
497 
498 ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
499   auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
500   return cast<ReturnInst>(registerValue(std::move(NewPtr)));
501 }
502 
503 CallInst *Context::createCallInst(llvm::CallInst *I) {
504   auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
505   return cast<CallInst>(registerValue(std::move(NewPtr)));
506 }
507 
508 InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
509   auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
510   return cast<InvokeInst>(registerValue(std::move(NewPtr)));
511 }
512 
513 CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
514   auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
515   return cast<CallBrInst>(registerValue(std::move(NewPtr)));
516 }
517 
518 UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
519   auto NewPtr =
520       std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
521   return cast<UnreachableInst>(registerValue(std::move(NewPtr)));
522 }
523 LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
524   auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
525   return cast<LandingPadInst>(registerValue(std::move(NewPtr)));
526 }
527 CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
528   auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
529   return cast<CatchPadInst>(registerValue(std::move(NewPtr)));
530 }
531 CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
532   auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
533   return cast<CleanupPadInst>(registerValue(std::move(NewPtr)));
534 }
535 CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
536   auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
537   return cast<CatchReturnInst>(registerValue(std::move(NewPtr)));
538 }
539 CleanupReturnInst *
540 Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
541   auto NewPtr =
542       std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
543   return cast<CleanupReturnInst>(registerValue(std::move(NewPtr)));
544 }
545 GetElementPtrInst *
546 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
547   auto NewPtr =
548       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
549   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
550 }
551 CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
552   auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
553   return cast<CatchSwitchInst>(registerValue(std::move(NewPtr)));
554 }
555 ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
556   auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
557   return cast<ResumeInst>(registerValue(std::move(NewPtr)));
558 }
559 SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
560   auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
561   return cast<SwitchInst>(registerValue(std::move(NewPtr)));
562 }
563 UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
564   auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
565   return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
566 }
567 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
568   auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
569   return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
570 }
571 AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
572   auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
573   return cast<AtomicRMWInst>(registerValue(std::move(NewPtr)));
574 }
575 AtomicCmpXchgInst *
576 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
577   auto NewPtr =
578       std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
579   return cast<AtomicCmpXchgInst>(registerValue(std::move(NewPtr)));
580 }
581 AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
582   auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
583   return cast<AllocaInst>(registerValue(std::move(NewPtr)));
584 }
585 CastInst *Context::createCastInst(llvm::CastInst *I) {
586   auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
587   return cast<CastInst>(registerValue(std::move(NewPtr)));
588 }
589 PHINode *Context::createPHINode(llvm::PHINode *I) {
590   auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
591   return cast<PHINode>(registerValue(std::move(NewPtr)));
592 }
593 ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
594   auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
595   return cast<ICmpInst>(registerValue(std::move(NewPtr)));
596 }
597 FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
598   auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
599   return cast<FCmpInst>(registerValue(std::move(NewPtr)));
600 }
601 Value *Context::getValue(llvm::Value *V) const {
602   auto It = LLVMValueToValueMap.find(V);
603   if (It != LLVMValueToValueMap.end())
604     return It->second.get();
605   return nullptr;
606 }
607 
608 Context::Context(LLVMContext &LLVMCtx)
609     : LLVMCtx(LLVMCtx), IRTracker(*this),
610       LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
611 
612 Context::~Context() {}
613 
614 void Context::clear() {
615   // TODO: Ideally we should clear only function-scope objects, and keep global
616   // objects, like Constants to avoid recreating them.
617   LLVMValueToValueMap.clear();
618 }
619 
620 Module *Context::getModule(llvm::Module *LLVMM) const {
621   auto It = LLVMModuleToModuleMap.find(LLVMM);
622   if (It != LLVMModuleToModuleMap.end())
623     return It->second.get();
624   return nullptr;
625 }
626 
627 Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
628   auto Pair = LLVMModuleToModuleMap.insert({LLVMM, nullptr});
629   auto It = Pair.first;
630   if (!Pair.second)
631     return It->second.get();
632   It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
633   return It->second.get();
634 }
635 
636 Function *Context::createFunction(llvm::Function *F) {
637   // Create the module if needed before we create the new sandboxir::Function.
638   // Note: this won't fully populate the module. The only globals that will be
639   // available will be the ones being used within the function.
640   getOrCreateModule(F->getParent());
641 
642   // There may be a function declaration already defined. Regardless destroy it.
643   if (Function *ExistingF = cast_or_null<Function>(getValue(F)))
644     detach(ExistingF);
645 
646   auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
647   auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
648   // Create arguments.
649   for (auto &Arg : F->args())
650     getOrCreateArgument(&Arg);
651   // Create BBs.
652   for (auto &BB : *F)
653     createBasicBlock(&BB);
654   return SBF;
655 }
656 
657 Module *Context::createModule(llvm::Module *LLVMM) {
658   auto *M = getOrCreateModule(LLVMM);
659   // Create the functions.
660   for (auto &LLVMF : *LLVMM)
661     createFunction(&LLVMF);
662   // Create globals.
663   for (auto &Global : LLVMM->globals())
664     getOrCreateValue(&Global);
665   // Create aliases.
666   for (auto &Alias : LLVMM->aliases())
667     getOrCreateValue(&Alias);
668   // Create ifuncs.
669   for (auto &IFunc : LLVMM->ifuncs())
670     getOrCreateValue(&IFunc);
671 
672   return M;
673 }
674 
675 void Context::runEraseInstrCallbacks(Instruction *I) {
676   for (const auto &CBEntry : EraseInstrCallbacks)
677     CBEntry.second(I);
678 }
679 
680 void Context::runCreateInstrCallbacks(Instruction *I) {
681   for (auto &CBEntry : CreateInstrCallbacks)
682     CBEntry.second(I);
683 }
684 
685 void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
686   for (auto &CBEntry : MoveInstrCallbacks)
687     CBEntry.second(I, WhereIt);
688 }
689 
690 // An arbitrary limit, to check for accidental misuse. We expect a small number
691 // of callbacks to be registered at a time, but we can increase this number if
692 // we discover we needed more.
693 [[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16;
694 
695 Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
696   assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
697          "EraseInstrCallbacks size limit exceeded");
698   CallbackID ID{NextCallbackID++};
699   EraseInstrCallbacks[ID] = CB;
700   return ID;
701 }
702 void Context::unregisterEraseInstrCallback(CallbackID ID) {
703   [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID);
704   assert(Erased &&
705          "Callback ID not found in EraseInstrCallbacks during deregistration");
706 }
707 
708 Context::CallbackID
709 Context::registerCreateInstrCallback(CreateInstrCallback CB) {
710   assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
711          "CreateInstrCallbacks size limit exceeded");
712   CallbackID ID{NextCallbackID++};
713   CreateInstrCallbacks[ID] = CB;
714   return ID;
715 }
716 void Context::unregisterCreateInstrCallback(CallbackID ID) {
717   [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID);
718   assert(Erased &&
719          "Callback ID not found in CreateInstrCallbacks during deregistration");
720 }
721 
722 Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
723   assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
724          "MoveInstrCallbacks size limit exceeded");
725   CallbackID ID{NextCallbackID++};
726   MoveInstrCallbacks[ID] = CB;
727   return ID;
728 }
729 void Context::unregisterMoveInstrCallback(CallbackID ID) {
730   [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
731   assert(Erased &&
732          "Callback ID not found in MoveInstrCallbacks during deregistration");
733 }
734 
735 } // namespace llvm::sandboxir
736