1 //===- Tracker.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/SandboxIR/Tracker.h" 10 #include "llvm/ADT/STLExtras.h" 11 #include "llvm/IR/BasicBlock.h" 12 #include "llvm/IR/Instruction.h" 13 #include "llvm/IR/Module.h" 14 #include "llvm/IR/StructuralHash.h" 15 #include "llvm/SandboxIR/Instruction.h" 16 #include <sstream> 17 18 using namespace llvm::sandboxir; 19 20 #ifndef NDEBUG 21 22 std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const { 23 std::string Result; 24 raw_string_ostream SS(Result); 25 F.print(SS, /*AssemblyAnnotationWriter=*/nullptr); 26 return Result; 27 } 28 29 IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const { 30 ContextSnapshot Result; 31 for (const auto &Entry : Ctx.LLVMModuleToModuleMap) 32 for (const auto &F : *Entry.first) { 33 FunctionSnapshot Snapshot; 34 Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true); 35 Snapshot.TextualIR = dumpIR(F); 36 Result[&F] = Snapshot; 37 } 38 return Result; 39 } 40 41 bool IRSnapshotChecker::diff(const ContextSnapshot &Orig, 42 const ContextSnapshot &Curr) const { 43 bool DifferenceFound = false; 44 for (const auto &[F, OrigFS] : Orig) { 45 auto CurrFSIt = Curr.find(F); 46 if (CurrFSIt == Curr.end()) { 47 DifferenceFound = true; 48 dbgs() << "Function " << F->getName() << " not found in current IR.\n"; 49 dbgs() << OrigFS.TextualIR << "\n"; 50 continue; 51 } 52 const FunctionSnapshot &CurrFS = CurrFSIt->second; 53 if (OrigFS.Hash != CurrFS.Hash) { 54 DifferenceFound = true; 55 dbgs() << "Found IR difference in Function " << F->getName() << "\n"; 56 dbgs() << "Original:\n" << OrigFS.TextualIR << "\n"; 57 dbgs() << "Current:\n" << CurrFS.TextualIR << "\n"; 58 } 59 } 60 // Check that Curr doesn't contain any new functions. 61 for (const auto &[F, CurrFS] : Curr) { 62 if (!Orig.contains(F)) { 63 DifferenceFound = true; 64 dbgs() << "Function " << F->getName() 65 << " found in current IR but not in original snapshot.\n"; 66 dbgs() << CurrFS.TextualIR << "\n"; 67 } 68 } 69 return DifferenceFound; 70 } 71 72 void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); } 73 74 void IRSnapshotChecker::expectNoDiff() { 75 ContextSnapshot CurrContextSnapshot = takeSnapshot(); 76 if (diff(OrigContextSnapshot, CurrContextSnapshot)) { 77 llvm_unreachable( 78 "Original and current IR differ! Probably a checkpointing bug."); 79 } 80 } 81 82 void UseSet::dump() const { 83 dump(dbgs()); 84 dbgs() << "\n"; 85 } 86 87 void UseSwap::dump() const { 88 dump(dbgs()); 89 dbgs() << "\n"; 90 } 91 #endif // NDEBUG 92 93 PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx) 94 : PHI(PHI), RemovedIdx(RemovedIdx) { 95 RemovedV = PHI->getIncomingValue(RemovedIdx); 96 RemovedBB = PHI->getIncomingBlock(RemovedIdx); 97 } 98 99 void PHIRemoveIncoming::revert(Tracker &Tracker) { 100 // Special case: if the PHI is now empty, as we don't need to care about the 101 // order of the incoming values. 102 unsigned NumIncoming = PHI->getNumIncomingValues(); 103 if (NumIncoming == 0) { 104 PHI->addIncoming(RemovedV, RemovedBB); 105 return; 106 } 107 // Shift all incoming values by one starting from the end until `Idx`. 108 // Start by adding a copy of the last incoming values. 109 unsigned LastIdx = NumIncoming - 1; 110 PHI->addIncoming(PHI->getIncomingValue(LastIdx), 111 PHI->getIncomingBlock(LastIdx)); 112 for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) { 113 auto *PrevV = PHI->getIncomingValue(Idx - 1); 114 auto *PrevBB = PHI->getIncomingBlock(Idx - 1); 115 PHI->setIncomingValue(Idx, PrevV); 116 PHI->setIncomingBlock(Idx, PrevBB); 117 } 118 PHI->setIncomingValue(RemovedIdx, RemovedV); 119 PHI->setIncomingBlock(RemovedIdx, RemovedBB); 120 } 121 122 #ifndef NDEBUG 123 void PHIRemoveIncoming::dump() const { 124 dump(dbgs()); 125 dbgs() << "\n"; 126 } 127 #endif // NDEBUG 128 129 PHIAddIncoming::PHIAddIncoming(PHINode *PHI) 130 : PHI(PHI), Idx(PHI->getNumIncomingValues()) {} 131 132 void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); } 133 134 #ifndef NDEBUG 135 void PHIAddIncoming::dump() const { 136 dump(dbgs()); 137 dbgs() << "\n"; 138 } 139 #endif // NDEBUG 140 141 Tracker::~Tracker() { 142 assert(Changes.empty() && "You must accept or revert changes!"); 143 } 144 145 EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr) 146 : ErasedIPtr(std::move(ErasedIPtr)) { 147 auto *I = cast<Instruction>(this->ErasedIPtr.get()); 148 auto LLVMInstrs = I->getLLVMInstrs(); 149 // Iterate in reverse program order. 150 for (auto *LLVMI : reverse(LLVMInstrs)) { 151 SmallVector<llvm::Value *> Operands; 152 Operands.reserve(LLVMI->getNumOperands()); 153 for (auto [OpNum, Use] : enumerate(LLVMI->operands())) 154 Operands.push_back(Use.get()); 155 InstrData.push_back({Operands, LLVMI}); 156 } 157 assert(is_sorted(InstrData, 158 [](const auto &D0, const auto &D1) { 159 return D0.LLVMI->comesBefore(D1.LLVMI); 160 }) && 161 "Expected reverse program order!"); 162 auto *BotLLVMI = cast<llvm::Instruction>(I->Val); 163 if (BotLLVMI->getNextNode() != nullptr) 164 NextLLVMIOrBB = BotLLVMI->getNextNode(); 165 else 166 NextLLVMIOrBB = BotLLVMI->getParent(); 167 } 168 169 void EraseFromParent::accept() { 170 for (const auto &IData : InstrData) 171 IData.LLVMI->deleteValue(); 172 } 173 174 void EraseFromParent::revert(Tracker &Tracker) { 175 // Place the bottom-most instruction first. 176 auto [Operands, BotLLVMI] = InstrData[0]; 177 if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(NextLLVMIOrBB)) { 178 BotLLVMI->insertBefore(NextLLVMI->getIterator()); 179 } else { 180 auto *LLVMBB = cast<llvm::BasicBlock *>(NextLLVMIOrBB); 181 BotLLVMI->insertInto(LLVMBB, LLVMBB->end()); 182 } 183 for (auto [OpNum, Op] : enumerate(Operands)) 184 BotLLVMI->setOperand(OpNum, Op); 185 186 // Go over the rest of the instructions and stack them on top. 187 for (auto [Operands, LLVMI] : drop_begin(InstrData)) { 188 LLVMI->insertBefore(BotLLVMI->getIterator()); 189 for (auto [OpNum, Op] : enumerate(Operands)) 190 LLVMI->setOperand(OpNum, Op); 191 BotLLVMI = LLVMI; 192 } 193 Tracker.getContext().registerValue(std::move(ErasedIPtr)); 194 } 195 196 #ifndef NDEBUG 197 void EraseFromParent::dump() const { 198 dump(dbgs()); 199 dbgs() << "\n"; 200 } 201 #endif // NDEBUG 202 203 RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) { 204 if (auto *NextI = RemovedI->getNextNode()) 205 NextInstrOrBB = NextI; 206 else 207 NextInstrOrBB = RemovedI->getParent(); 208 } 209 210 void RemoveFromParent::revert(Tracker &Tracker) { 211 if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) { 212 RemovedI->insertBefore(NextI); 213 } else { 214 auto *BB = cast<BasicBlock *>(NextInstrOrBB); 215 RemovedI->insertInto(BB, BB->end()); 216 } 217 } 218 219 #ifndef NDEBUG 220 void RemoveFromParent::dump() const { 221 dump(dbgs()); 222 dbgs() << "\n"; 223 } 224 #endif 225 226 CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI) 227 : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {} 228 229 void CatchSwitchAddHandler::revert(Tracker &Tracker) { 230 // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler() 231 // once it gets implemented. 232 auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val); 233 LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx); 234 } 235 236 SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) { 237 for (const auto &C : Switch->cases()) 238 Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()}); 239 } 240 241 void SwitchRemoveCase::revert(Tracker &Tracker) { 242 // SwitchInst::removeCase doesn't provide any guarantees about the order of 243 // cases after removal. In order to preserve the original ordering, we save 244 // all of them and, when reverting, clear them all then insert them in the 245 // desired order. This still relies on the fact that `addCase` will insert 246 // them at the end, but it is documented to invalidate `case_end()` so it's 247 // probably okay. 248 unsigned NumCases = Switch->getNumCases(); 249 for (unsigned I = 0; I < NumCases; ++I) 250 Switch->removeCase(Switch->case_begin()); 251 for (auto &Case : Cases) 252 Switch->addCase(Case.Val, Case.Dest); 253 } 254 255 #ifndef NDEBUG 256 void SwitchRemoveCase::dump() const { 257 dump(dbgs()); 258 dbgs() << "\n"; 259 } 260 #endif // NDEBUG 261 262 void SwitchAddCase::revert(Tracker &Tracker) { 263 auto It = Switch->findCaseValue(Val); 264 Switch->removeCase(It); 265 } 266 267 #ifndef NDEBUG 268 void SwitchAddCase::dump() const { 269 dump(dbgs()); 270 dbgs() << "\n"; 271 } 272 #endif // NDEBUG 273 274 MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) { 275 if (auto *NextI = MovedI->getNextNode()) 276 NextInstrOrBB = NextI; 277 else 278 NextInstrOrBB = MovedI->getParent(); 279 } 280 281 void MoveInstr::revert(Tracker &Tracker) { 282 if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) { 283 MovedI->moveBefore(NextI); 284 } else { 285 auto *BB = cast<BasicBlock *>(NextInstrOrBB); 286 MovedI->moveBefore(*BB, BB->end()); 287 } 288 } 289 290 #ifndef NDEBUG 291 void MoveInstr::dump() const { 292 dump(dbgs()); 293 dbgs() << "\n"; 294 } 295 #endif 296 297 void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); } 298 299 InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {} 300 301 #ifndef NDEBUG 302 void InsertIntoBB::dump() const { 303 dump(dbgs()); 304 dbgs() << "\n"; 305 } 306 #endif 307 308 void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); } 309 310 #ifndef NDEBUG 311 void CreateAndInsertInst::dump() const { 312 dump(dbgs()); 313 dbgs() << "\n"; 314 } 315 #endif 316 317 ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI) 318 : SVI(SVI), PrevMask(SVI->getShuffleMask()) {} 319 320 void ShuffleVectorSetMask::revert(Tracker &Tracker) { 321 SVI->setShuffleMask(PrevMask); 322 } 323 324 #ifndef NDEBUG 325 void ShuffleVectorSetMask::dump() const { 326 dump(dbgs()); 327 dbgs() << "\n"; 328 } 329 #endif 330 331 CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {} 332 333 void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); } 334 #ifndef NDEBUG 335 void CmpSwapOperands::dump() const { 336 dump(dbgs()); 337 dbgs() << "\n"; 338 } 339 #endif 340 341 void Tracker::save() { 342 State = TrackerState::Record; 343 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) 344 SnapshotChecker.save(); 345 #endif 346 } 347 348 void Tracker::revert() { 349 assert(State == TrackerState::Record && "Forgot to save()!"); 350 State = TrackerState::Disabled; 351 for (auto &Change : reverse(Changes)) 352 Change->revert(*this); 353 Changes.clear(); 354 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) 355 SnapshotChecker.expectNoDiff(); 356 #endif 357 } 358 359 void Tracker::accept() { 360 assert(State == TrackerState::Record && "Forgot to save()!"); 361 State = TrackerState::Disabled; 362 for (auto &Change : Changes) 363 Change->accept(); 364 Changes.clear(); 365 } 366 367 #ifndef NDEBUG 368 void Tracker::dump(raw_ostream &OS) const { 369 for (auto [Idx, ChangePtr] : enumerate(Changes)) { 370 OS << Idx << ". "; 371 ChangePtr->dump(OS); 372 OS << "\n"; 373 } 374 } 375 void Tracker::dump() const { 376 dump(dbgs()); 377 dbgs() << "\n"; 378 } 379 #endif // NDEBUG 380