1 //===- Context.cpp - The Context class of Sandbox IR ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/SandboxIR/Context.h" 10 #include "llvm/SandboxIR/Function.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Module.h" 13 14 namespace llvm::sandboxir { 15 16 std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) { 17 std::unique_ptr<Value> Erased; 18 auto It = LLVMValueToValueMap.find(V); 19 if (It != LLVMValueToValueMap.end()) { 20 auto *Val = It->second.release(); 21 Erased = std::unique_ptr<Value>(Val); 22 LLVMValueToValueMap.erase(It); 23 } 24 return Erased; 25 } 26 27 std::unique_ptr<Value> Context::detach(Value *V) { 28 assert(V->getSubclassID() != Value::ClassID::Constant && 29 "Can't detach a constant!"); 30 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!"); 31 return detachLLVMValue(V->Val); 32 } 33 34 Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) { 35 assert(VPtr->getSubclassID() != Value::ClassID::User && 36 "Can't register a user!"); 37 38 Value *V = VPtr.get(); 39 [[maybe_unused]] auto Pair = 40 LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)}); 41 assert(Pair.second && "Already exists!"); 42 43 // Track creation of instructions. 44 // Please note that we don't allow the creation of detached instructions, 45 // meaning that the instructions need to be inserted into a block upon 46 // creation. This is why the tracker class combines creation and insertion. 47 if (auto *I = dyn_cast<Instruction>(V)) { 48 getTracker().emplaceIfTracking<CreateAndInsertInst>(I); 49 runCreateInstrCallbacks(I); 50 } 51 52 return V; 53 } 54 55 Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { 56 auto Pair = LLVMValueToValueMap.insert({LLVMV, nullptr}); 57 auto It = Pair.first; 58 if (!Pair.second) 59 return It->second.get(); 60 61 if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) { 62 switch (C->getValueID()) { 63 case llvm::Value::ConstantIntVal: 64 It->second = std::unique_ptr<ConstantInt>( 65 new ConstantInt(cast<llvm::ConstantInt>(C), *this)); 66 return It->second.get(); 67 case llvm::Value::ConstantFPVal: 68 It->second = std::unique_ptr<ConstantFP>( 69 new ConstantFP(cast<llvm::ConstantFP>(C), *this)); 70 return It->second.get(); 71 case llvm::Value::BlockAddressVal: 72 It->second = std::unique_ptr<BlockAddress>( 73 new BlockAddress(cast<llvm::BlockAddress>(C), *this)); 74 return It->second.get(); 75 case llvm::Value::ConstantTokenNoneVal: 76 It->second = std::unique_ptr<ConstantTokenNone>( 77 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(C), *this)); 78 return It->second.get(); 79 case llvm::Value::ConstantAggregateZeroVal: { 80 auto *CAZ = cast<llvm::ConstantAggregateZero>(C); 81 It->second = std::unique_ptr<ConstantAggregateZero>( 82 new ConstantAggregateZero(CAZ, *this)); 83 auto *Ret = It->second.get(); 84 // Must create sandboxir for elements. 85 auto EC = CAZ->getElementCount(); 86 if (EC.isFixed()) { 87 for (auto ElmIdx : seq<unsigned>(0, EC.getFixedValue())) 88 getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ); 89 } 90 return Ret; 91 } 92 case llvm::Value::ConstantPointerNullVal: 93 It->second = std::unique_ptr<ConstantPointerNull>( 94 new ConstantPointerNull(cast<llvm::ConstantPointerNull>(C), *this)); 95 return It->second.get(); 96 case llvm::Value::PoisonValueVal: 97 It->second = std::unique_ptr<PoisonValue>( 98 new PoisonValue(cast<llvm::PoisonValue>(C), *this)); 99 return It->second.get(); 100 case llvm::Value::UndefValueVal: 101 It->second = std::unique_ptr<UndefValue>( 102 new UndefValue(cast<llvm::UndefValue>(C), *this)); 103 return It->second.get(); 104 case llvm::Value::DSOLocalEquivalentVal: { 105 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(C); 106 It->second = std::unique_ptr<DSOLocalEquivalent>( 107 new DSOLocalEquivalent(DSOLE, *this)); 108 auto *Ret = It->second.get(); 109 getOrCreateValueInternal(DSOLE->getGlobalValue(), DSOLE); 110 return Ret; 111 } 112 case llvm::Value::ConstantArrayVal: 113 It->second = std::unique_ptr<ConstantArray>( 114 new ConstantArray(cast<llvm::ConstantArray>(C), *this)); 115 break; 116 case llvm::Value::ConstantStructVal: 117 It->second = std::unique_ptr<ConstantStruct>( 118 new ConstantStruct(cast<llvm::ConstantStruct>(C), *this)); 119 break; 120 case llvm::Value::ConstantVectorVal: 121 It->second = std::unique_ptr<ConstantVector>( 122 new ConstantVector(cast<llvm::ConstantVector>(C), *this)); 123 break; 124 case llvm::Value::FunctionVal: 125 It->second = std::unique_ptr<Function>( 126 new Function(cast<llvm::Function>(C), *this)); 127 break; 128 case llvm::Value::GlobalIFuncVal: 129 It->second = std::unique_ptr<GlobalIFunc>( 130 new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this)); 131 break; 132 case llvm::Value::GlobalVariableVal: 133 It->second = std::unique_ptr<GlobalVariable>( 134 new GlobalVariable(cast<llvm::GlobalVariable>(C), *this)); 135 break; 136 case llvm::Value::GlobalAliasVal: 137 It->second = std::unique_ptr<GlobalAlias>( 138 new GlobalAlias(cast<llvm::GlobalAlias>(C), *this)); 139 break; 140 case llvm::Value::NoCFIValueVal: 141 It->second = std::unique_ptr<NoCFIValue>( 142 new NoCFIValue(cast<llvm::NoCFIValue>(C), *this)); 143 break; 144 case llvm::Value::ConstantPtrAuthVal: 145 It->second = std::unique_ptr<ConstantPtrAuth>( 146 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(C), *this)); 147 break; 148 case llvm::Value::ConstantExprVal: 149 It->second = std::unique_ptr<ConstantExpr>( 150 new ConstantExpr(cast<llvm::ConstantExpr>(C), *this)); 151 break; 152 default: 153 It->second = std::unique_ptr<Constant>(new Constant(C, *this)); 154 break; 155 } 156 auto *NewC = It->second.get(); 157 for (llvm::Value *COp : C->operands()) 158 getOrCreateValueInternal(COp, C); 159 return NewC; 160 } 161 if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) { 162 It->second = std::unique_ptr<Argument>(new Argument(Arg, *this)); 163 return It->second.get(); 164 } 165 if (auto *BB = dyn_cast<llvm::BasicBlock>(LLVMV)) { 166 assert(isa<llvm::BlockAddress>(U) && 167 "This won't create a SBBB, don't call this function directly!"); 168 if (auto *SBBB = getValue(BB)) 169 return SBBB; 170 return nullptr; 171 } 172 assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction"); 173 174 switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) { 175 case llvm::Instruction::VAArg: { 176 auto *LLVMVAArg = cast<llvm::VAArgInst>(LLVMV); 177 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this)); 178 return It->second.get(); 179 } 180 case llvm::Instruction::Freeze: { 181 auto *LLVMFreeze = cast<llvm::FreezeInst>(LLVMV); 182 It->second = std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this)); 183 return It->second.get(); 184 } 185 case llvm::Instruction::Fence: { 186 auto *LLVMFence = cast<llvm::FenceInst>(LLVMV); 187 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this)); 188 return It->second.get(); 189 } 190 case llvm::Instruction::Select: { 191 auto *LLVMSel = cast<llvm::SelectInst>(LLVMV); 192 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this)); 193 return It->second.get(); 194 } 195 case llvm::Instruction::ExtractElement: { 196 auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV); 197 It->second = std::unique_ptr<ExtractElementInst>( 198 new ExtractElementInst(LLVMIns, *this)); 199 return It->second.get(); 200 } 201 case llvm::Instruction::InsertElement: { 202 auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV); 203 It->second = std::unique_ptr<InsertElementInst>( 204 new InsertElementInst(LLVMIns, *this)); 205 return It->second.get(); 206 } 207 case llvm::Instruction::ShuffleVector: { 208 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(LLVMV); 209 It->second = std::unique_ptr<ShuffleVectorInst>( 210 new ShuffleVectorInst(LLVMIns, *this)); 211 return It->second.get(); 212 } 213 case llvm::Instruction::ExtractValue: { 214 auto *LLVMIns = cast<llvm::ExtractValueInst>(LLVMV); 215 It->second = 216 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(LLVMIns, *this)); 217 return It->second.get(); 218 } 219 case llvm::Instruction::InsertValue: { 220 auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV); 221 It->second = 222 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this)); 223 return It->second.get(); 224 } 225 case llvm::Instruction::Br: { 226 auto *LLVMBr = cast<llvm::BranchInst>(LLVMV); 227 It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this)); 228 return It->second.get(); 229 } 230 case llvm::Instruction::Load: { 231 auto *LLVMLd = cast<llvm::LoadInst>(LLVMV); 232 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this)); 233 return It->second.get(); 234 } 235 case llvm::Instruction::Store: { 236 auto *LLVMSt = cast<llvm::StoreInst>(LLVMV); 237 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this)); 238 return It->second.get(); 239 } 240 case llvm::Instruction::Ret: { 241 auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV); 242 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this)); 243 return It->second.get(); 244 } 245 case llvm::Instruction::Call: { 246 auto *LLVMCall = cast<llvm::CallInst>(LLVMV); 247 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this)); 248 return It->second.get(); 249 } 250 case llvm::Instruction::Invoke: { 251 auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV); 252 It->second = std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this)); 253 return It->second.get(); 254 } 255 case llvm::Instruction::CallBr: { 256 auto *LLVMCallBr = cast<llvm::CallBrInst>(LLVMV); 257 It->second = std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this)); 258 return It->second.get(); 259 } 260 case llvm::Instruction::LandingPad: { 261 auto *LLVMLPad = cast<llvm::LandingPadInst>(LLVMV); 262 It->second = 263 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this)); 264 return It->second.get(); 265 } 266 case llvm::Instruction::CatchPad: { 267 auto *LLVMCPI = cast<llvm::CatchPadInst>(LLVMV); 268 It->second = 269 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this)); 270 return It->second.get(); 271 } 272 case llvm::Instruction::CleanupPad: { 273 auto *LLVMCPI = cast<llvm::CleanupPadInst>(LLVMV); 274 It->second = 275 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this)); 276 return It->second.get(); 277 } 278 case llvm::Instruction::CatchRet: { 279 auto *LLVMCRI = cast<llvm::CatchReturnInst>(LLVMV); 280 It->second = 281 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this)); 282 return It->second.get(); 283 } 284 case llvm::Instruction::CleanupRet: { 285 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(LLVMV); 286 It->second = std::unique_ptr<CleanupReturnInst>( 287 new CleanupReturnInst(LLVMCRI, *this)); 288 return It->second.get(); 289 } 290 case llvm::Instruction::GetElementPtr: { 291 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV); 292 It->second = std::unique_ptr<GetElementPtrInst>( 293 new GetElementPtrInst(LLVMGEP, *this)); 294 return It->second.get(); 295 } 296 case llvm::Instruction::CatchSwitch: { 297 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(LLVMV); 298 It->second = std::unique_ptr<CatchSwitchInst>( 299 new CatchSwitchInst(LLVMCatchSwitchInst, *this)); 300 return It->second.get(); 301 } 302 case llvm::Instruction::Resume: { 303 auto *LLVMResumeInst = cast<llvm::ResumeInst>(LLVMV); 304 It->second = 305 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this)); 306 return It->second.get(); 307 } 308 case llvm::Instruction::Switch: { 309 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV); 310 It->second = 311 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this)); 312 return It->second.get(); 313 } 314 case llvm::Instruction::FNeg: { 315 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV); 316 It->second = std::unique_ptr<UnaryOperator>( 317 new UnaryOperator(LLVMUnaryOperator, *this)); 318 return It->second.get(); 319 } 320 case llvm::Instruction::Add: 321 case llvm::Instruction::FAdd: 322 case llvm::Instruction::Sub: 323 case llvm::Instruction::FSub: 324 case llvm::Instruction::Mul: 325 case llvm::Instruction::FMul: 326 case llvm::Instruction::UDiv: 327 case llvm::Instruction::SDiv: 328 case llvm::Instruction::FDiv: 329 case llvm::Instruction::URem: 330 case llvm::Instruction::SRem: 331 case llvm::Instruction::FRem: 332 case llvm::Instruction::Shl: 333 case llvm::Instruction::LShr: 334 case llvm::Instruction::AShr: 335 case llvm::Instruction::And: 336 case llvm::Instruction::Or: 337 case llvm::Instruction::Xor: { 338 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV); 339 It->second = std::unique_ptr<BinaryOperator>( 340 new BinaryOperator(LLVMBinaryOperator, *this)); 341 return It->second.get(); 342 } 343 case llvm::Instruction::AtomicRMW: { 344 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV); 345 It->second = 346 std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(LLVMAtomicRMW, *this)); 347 return It->second.get(); 348 } 349 case llvm::Instruction::AtomicCmpXchg: { 350 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV); 351 It->second = std::unique_ptr<AtomicCmpXchgInst>( 352 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this)); 353 return It->second.get(); 354 } 355 case llvm::Instruction::Alloca: { 356 auto *LLVMAlloca = cast<llvm::AllocaInst>(LLVMV); 357 It->second = std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this)); 358 return It->second.get(); 359 } 360 case llvm::Instruction::ZExt: 361 case llvm::Instruction::SExt: 362 case llvm::Instruction::FPToUI: 363 case llvm::Instruction::FPToSI: 364 case llvm::Instruction::FPExt: 365 case llvm::Instruction::PtrToInt: 366 case llvm::Instruction::IntToPtr: 367 case llvm::Instruction::SIToFP: 368 case llvm::Instruction::UIToFP: 369 case llvm::Instruction::Trunc: 370 case llvm::Instruction::FPTrunc: 371 case llvm::Instruction::BitCast: 372 case llvm::Instruction::AddrSpaceCast: { 373 auto *LLVMCast = cast<llvm::CastInst>(LLVMV); 374 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this)); 375 return It->second.get(); 376 } 377 case llvm::Instruction::PHI: { 378 auto *LLVMPhi = cast<llvm::PHINode>(LLVMV); 379 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this)); 380 return It->second.get(); 381 } 382 case llvm::Instruction::ICmp: { 383 auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV); 384 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this)); 385 return It->second.get(); 386 } 387 case llvm::Instruction::FCmp: { 388 auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV); 389 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this)); 390 return It->second.get(); 391 } 392 case llvm::Instruction::Unreachable: { 393 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV); 394 It->second = std::unique_ptr<UnreachableInst>( 395 new UnreachableInst(LLVMUnreachable, *this)); 396 return It->second.get(); 397 } 398 default: 399 break; 400 } 401 402 It->second = std::unique_ptr<OpaqueInst>( 403 new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this)); 404 return It->second.get(); 405 } 406 407 Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) { 408 auto Pair = LLVMValueToValueMap.insert({LLVMArg, nullptr}); 409 auto It = Pair.first; 410 if (Pair.second) { 411 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this)); 412 return cast<Argument>(It->second.get()); 413 } 414 return cast<Argument>(It->second.get()); 415 } 416 417 Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) { 418 return cast<Constant>(getOrCreateValueInternal(LLVMC, 0)); 419 } 420 421 BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) { 422 assert(getValue(LLVMBB) == nullptr && "Already exists!"); 423 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this)); 424 auto *BB = cast<BasicBlock>(registerValue(std::move(NewBBPtr))); 425 // Create SandboxIR for BB's body. 426 BB->buildBasicBlockFromLLVMIR(LLVMBB); 427 return BB; 428 } 429 430 VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) { 431 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this)); 432 return cast<VAArgInst>(registerValue(std::move(NewPtr))); 433 } 434 435 FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) { 436 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this)); 437 return cast<FreezeInst>(registerValue(std::move(NewPtr))); 438 } 439 440 FenceInst *Context::createFenceInst(llvm::FenceInst *SI) { 441 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this)); 442 return cast<FenceInst>(registerValue(std::move(NewPtr))); 443 } 444 445 SelectInst *Context::createSelectInst(llvm::SelectInst *SI) { 446 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this)); 447 return cast<SelectInst>(registerValue(std::move(NewPtr))); 448 } 449 450 ExtractElementInst * 451 Context::createExtractElementInst(llvm::ExtractElementInst *EEI) { 452 auto NewPtr = 453 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this)); 454 return cast<ExtractElementInst>(registerValue(std::move(NewPtr))); 455 } 456 457 InsertElementInst * 458 Context::createInsertElementInst(llvm::InsertElementInst *IEI) { 459 auto NewPtr = 460 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this)); 461 return cast<InsertElementInst>(registerValue(std::move(NewPtr))); 462 } 463 464 ShuffleVectorInst * 465 Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) { 466 auto NewPtr = 467 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this)); 468 return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr))); 469 } 470 471 ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) { 472 auto NewPtr = 473 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this)); 474 return cast<ExtractValueInst>(registerValue(std::move(NewPtr))); 475 } 476 477 InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) { 478 auto NewPtr = 479 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this)); 480 return cast<InsertValueInst>(registerValue(std::move(NewPtr))); 481 } 482 483 BranchInst *Context::createBranchInst(llvm::BranchInst *BI) { 484 auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this)); 485 return cast<BranchInst>(registerValue(std::move(NewPtr))); 486 } 487 488 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) { 489 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this)); 490 return cast<LoadInst>(registerValue(std::move(NewPtr))); 491 } 492 493 StoreInst *Context::createStoreInst(llvm::StoreInst *SI) { 494 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this)); 495 return cast<StoreInst>(registerValue(std::move(NewPtr))); 496 } 497 498 ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) { 499 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this)); 500 return cast<ReturnInst>(registerValue(std::move(NewPtr))); 501 } 502 503 CallInst *Context::createCallInst(llvm::CallInst *I) { 504 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this)); 505 return cast<CallInst>(registerValue(std::move(NewPtr))); 506 } 507 508 InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) { 509 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this)); 510 return cast<InvokeInst>(registerValue(std::move(NewPtr))); 511 } 512 513 CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) { 514 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this)); 515 return cast<CallBrInst>(registerValue(std::move(NewPtr))); 516 } 517 518 UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) { 519 auto NewPtr = 520 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this)); 521 return cast<UnreachableInst>(registerValue(std::move(NewPtr))); 522 } 523 LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) { 524 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this)); 525 return cast<LandingPadInst>(registerValue(std::move(NewPtr))); 526 } 527 CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) { 528 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this)); 529 return cast<CatchPadInst>(registerValue(std::move(NewPtr))); 530 } 531 CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) { 532 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this)); 533 return cast<CleanupPadInst>(registerValue(std::move(NewPtr))); 534 } 535 CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) { 536 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this)); 537 return cast<CatchReturnInst>(registerValue(std::move(NewPtr))); 538 } 539 CleanupReturnInst * 540 Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) { 541 auto NewPtr = 542 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this)); 543 return cast<CleanupReturnInst>(registerValue(std::move(NewPtr))); 544 } 545 GetElementPtrInst * 546 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) { 547 auto NewPtr = 548 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this)); 549 return cast<GetElementPtrInst>(registerValue(std::move(NewPtr))); 550 } 551 CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) { 552 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this)); 553 return cast<CatchSwitchInst>(registerValue(std::move(NewPtr))); 554 } 555 ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) { 556 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this)); 557 return cast<ResumeInst>(registerValue(std::move(NewPtr))); 558 } 559 SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) { 560 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this)); 561 return cast<SwitchInst>(registerValue(std::move(NewPtr))); 562 } 563 UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) { 564 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this)); 565 return cast<UnaryOperator>(registerValue(std::move(NewPtr))); 566 } 567 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) { 568 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this)); 569 return cast<BinaryOperator>(registerValue(std::move(NewPtr))); 570 } 571 AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) { 572 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this)); 573 return cast<AtomicRMWInst>(registerValue(std::move(NewPtr))); 574 } 575 AtomicCmpXchgInst * 576 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) { 577 auto NewPtr = 578 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this)); 579 return cast<AtomicCmpXchgInst>(registerValue(std::move(NewPtr))); 580 } 581 AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) { 582 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this)); 583 return cast<AllocaInst>(registerValue(std::move(NewPtr))); 584 } 585 CastInst *Context::createCastInst(llvm::CastInst *I) { 586 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this)); 587 return cast<CastInst>(registerValue(std::move(NewPtr))); 588 } 589 PHINode *Context::createPHINode(llvm::PHINode *I) { 590 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this)); 591 return cast<PHINode>(registerValue(std::move(NewPtr))); 592 } 593 ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) { 594 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this)); 595 return cast<ICmpInst>(registerValue(std::move(NewPtr))); 596 } 597 FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) { 598 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this)); 599 return cast<FCmpInst>(registerValue(std::move(NewPtr))); 600 } 601 Value *Context::getValue(llvm::Value *V) const { 602 auto It = LLVMValueToValueMap.find(V); 603 if (It != LLVMValueToValueMap.end()) 604 return It->second.get(); 605 return nullptr; 606 } 607 608 Context::Context(LLVMContext &LLVMCtx) 609 : LLVMCtx(LLVMCtx), IRTracker(*this), 610 LLVMIRBuilder(LLVMCtx, ConstantFolder()) {} 611 612 Context::~Context() {} 613 614 void Context::clear() { 615 // TODO: Ideally we should clear only function-scope objects, and keep global 616 // objects, like Constants to avoid recreating them. 617 LLVMValueToValueMap.clear(); 618 } 619 620 Module *Context::getModule(llvm::Module *LLVMM) const { 621 auto It = LLVMModuleToModuleMap.find(LLVMM); 622 if (It != LLVMModuleToModuleMap.end()) 623 return It->second.get(); 624 return nullptr; 625 } 626 627 Module *Context::getOrCreateModule(llvm::Module *LLVMM) { 628 auto Pair = LLVMModuleToModuleMap.insert({LLVMM, nullptr}); 629 auto It = Pair.first; 630 if (!Pair.second) 631 return It->second.get(); 632 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this)); 633 return It->second.get(); 634 } 635 636 Function *Context::createFunction(llvm::Function *F) { 637 // Create the module if needed before we create the new sandboxir::Function. 638 // Note: this won't fully populate the module. The only globals that will be 639 // available will be the ones being used within the function. 640 getOrCreateModule(F->getParent()); 641 642 // There may be a function declaration already defined. Regardless destroy it. 643 if (Function *ExistingF = cast_or_null<Function>(getValue(F))) 644 detach(ExistingF); 645 646 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this)); 647 auto *SBF = cast<Function>(registerValue(std::move(NewFPtr))); 648 // Create arguments. 649 for (auto &Arg : F->args()) 650 getOrCreateArgument(&Arg); 651 // Create BBs. 652 for (auto &BB : *F) 653 createBasicBlock(&BB); 654 return SBF; 655 } 656 657 Module *Context::createModule(llvm::Module *LLVMM) { 658 auto *M = getOrCreateModule(LLVMM); 659 // Create the functions. 660 for (auto &LLVMF : *LLVMM) 661 createFunction(&LLVMF); 662 // Create globals. 663 for (auto &Global : LLVMM->globals()) 664 getOrCreateValue(&Global); 665 // Create aliases. 666 for (auto &Alias : LLVMM->aliases()) 667 getOrCreateValue(&Alias); 668 // Create ifuncs. 669 for (auto &IFunc : LLVMM->ifuncs()) 670 getOrCreateValue(&IFunc); 671 672 return M; 673 } 674 675 void Context::runEraseInstrCallbacks(Instruction *I) { 676 for (const auto &CBEntry : EraseInstrCallbacks) 677 CBEntry.second(I); 678 } 679 680 void Context::runCreateInstrCallbacks(Instruction *I) { 681 for (auto &CBEntry : CreateInstrCallbacks) 682 CBEntry.second(I); 683 } 684 685 void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { 686 for (auto &CBEntry : MoveInstrCallbacks) 687 CBEntry.second(I, WhereIt); 688 } 689 690 // An arbitrary limit, to check for accidental misuse. We expect a small number 691 // of callbacks to be registered at a time, but we can increase this number if 692 // we discover we needed more. 693 [[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16; 694 695 Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) { 696 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks && 697 "EraseInstrCallbacks size limit exceeded"); 698 CallbackID ID{NextCallbackID++}; 699 EraseInstrCallbacks[ID] = CB; 700 return ID; 701 } 702 void Context::unregisterEraseInstrCallback(CallbackID ID) { 703 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID); 704 assert(Erased && 705 "Callback ID not found in EraseInstrCallbacks during deregistration"); 706 } 707 708 Context::CallbackID 709 Context::registerCreateInstrCallback(CreateInstrCallback CB) { 710 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks && 711 "CreateInstrCallbacks size limit exceeded"); 712 CallbackID ID{NextCallbackID++}; 713 CreateInstrCallbacks[ID] = CB; 714 return ID; 715 } 716 void Context::unregisterCreateInstrCallback(CallbackID ID) { 717 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID); 718 assert(Erased && 719 "Callback ID not found in CreateInstrCallbacks during deregistration"); 720 } 721 722 Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) { 723 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks && 724 "MoveInstrCallbacks size limit exceeded"); 725 CallbackID ID{NextCallbackID++}; 726 MoveInstrCallbacks[ID] = CB; 727 return ID; 728 } 729 void Context::unregisterMoveInstrCallback(CallbackID ID) { 730 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID); 731 assert(Erased && 732 "Callback ID not found in MoveInstrCallbacks during deregistration"); 733 } 734 735 } // namespace llvm::sandboxir 736