xref: /llvm-project/llvm/lib/SandboxIR/Tracker.cpp (revision 749443a307e8e47a25a5552cbeb27f69845e6ce8)
1 //===- Tracker.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/SandboxIR/Tracker.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "llvm/IR/BasicBlock.h"
12 #include "llvm/IR/Instruction.h"
13 #include "llvm/IR/Module.h"
14 #include "llvm/IR/StructuralHash.h"
15 #include "llvm/SandboxIR/Instruction.h"
16 #include <sstream>
17 
18 using namespace llvm::sandboxir;
19 
20 #ifndef NDEBUG
21 
22 std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const {
23   std::string Result;
24   raw_string_ostream SS(Result);
25   F.print(SS, /*AssemblyAnnotationWriter=*/nullptr);
26   return Result;
27 }
28 
29 IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const {
30   ContextSnapshot Result;
31   for (const auto &Entry : Ctx.LLVMModuleToModuleMap)
32     for (const auto &F : *Entry.first) {
33       FunctionSnapshot Snapshot;
34       Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true);
35       Snapshot.TextualIR = dumpIR(F);
36       Result[&F] = Snapshot;
37     }
38   return Result;
39 }
40 
41 bool IRSnapshotChecker::diff(const ContextSnapshot &Orig,
42                              const ContextSnapshot &Curr) const {
43   bool DifferenceFound = false;
44   for (const auto &[F, OrigFS] : Orig) {
45     auto CurrFSIt = Curr.find(F);
46     if (CurrFSIt == Curr.end()) {
47       DifferenceFound = true;
48       dbgs() << "Function " << F->getName() << " not found in current IR.\n";
49       dbgs() << OrigFS.TextualIR << "\n";
50       continue;
51     }
52     const FunctionSnapshot &CurrFS = CurrFSIt->second;
53     if (OrigFS.Hash != CurrFS.Hash) {
54       DifferenceFound = true;
55       dbgs() << "Found IR difference in Function " << F->getName() << "\n";
56       dbgs() << "Original:\n" << OrigFS.TextualIR << "\n";
57       dbgs() << "Current:\n" << CurrFS.TextualIR << "\n";
58     }
59   }
60   // Check that Curr doesn't contain any new functions.
61   for (const auto &[F, CurrFS] : Curr) {
62     if (!Orig.contains(F)) {
63       DifferenceFound = true;
64       dbgs() << "Function " << F->getName()
65              << " found in current IR but not in original snapshot.\n";
66       dbgs() << CurrFS.TextualIR << "\n";
67     }
68   }
69   return DifferenceFound;
70 }
71 
72 void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); }
73 
74 void IRSnapshotChecker::expectNoDiff() {
75   ContextSnapshot CurrContextSnapshot = takeSnapshot();
76   if (diff(OrigContextSnapshot, CurrContextSnapshot)) {
77     llvm_unreachable(
78         "Original and current IR differ! Probably a checkpointing bug.");
79   }
80 }
81 
82 void UseSet::dump() const {
83   dump(dbgs());
84   dbgs() << "\n";
85 }
86 
87 void UseSwap::dump() const {
88   dump(dbgs());
89   dbgs() << "\n";
90 }
91 #endif // NDEBUG
92 
93 PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx)
94     : PHI(PHI), RemovedIdx(RemovedIdx) {
95   RemovedV = PHI->getIncomingValue(RemovedIdx);
96   RemovedBB = PHI->getIncomingBlock(RemovedIdx);
97 }
98 
99 void PHIRemoveIncoming::revert(Tracker &Tracker) {
100   // Special case: if the PHI is now empty, as we don't need to care about the
101   // order of the incoming values.
102   unsigned NumIncoming = PHI->getNumIncomingValues();
103   if (NumIncoming == 0) {
104     PHI->addIncoming(RemovedV, RemovedBB);
105     return;
106   }
107   // Shift all incoming values by one starting from the end until `Idx`.
108   // Start by adding a copy of the last incoming values.
109   unsigned LastIdx = NumIncoming - 1;
110   PHI->addIncoming(PHI->getIncomingValue(LastIdx),
111                    PHI->getIncomingBlock(LastIdx));
112   for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) {
113     auto *PrevV = PHI->getIncomingValue(Idx - 1);
114     auto *PrevBB = PHI->getIncomingBlock(Idx - 1);
115     PHI->setIncomingValue(Idx, PrevV);
116     PHI->setIncomingBlock(Idx, PrevBB);
117   }
118   PHI->setIncomingValue(RemovedIdx, RemovedV);
119   PHI->setIncomingBlock(RemovedIdx, RemovedBB);
120 }
121 
122 #ifndef NDEBUG
123 void PHIRemoveIncoming::dump() const {
124   dump(dbgs());
125   dbgs() << "\n";
126 }
127 #endif // NDEBUG
128 
129 PHIAddIncoming::PHIAddIncoming(PHINode *PHI)
130     : PHI(PHI), Idx(PHI->getNumIncomingValues()) {}
131 
132 void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); }
133 
134 #ifndef NDEBUG
135 void PHIAddIncoming::dump() const {
136   dump(dbgs());
137   dbgs() << "\n";
138 }
139 #endif // NDEBUG
140 
141 Tracker::~Tracker() {
142   assert(Changes.empty() && "You must accept or revert changes!");
143 }
144 
145 EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr)
146     : ErasedIPtr(std::move(ErasedIPtr)) {
147   auto *I = cast<Instruction>(this->ErasedIPtr.get());
148   auto LLVMInstrs = I->getLLVMInstrs();
149   // Iterate in reverse program order.
150   for (auto *LLVMI : reverse(LLVMInstrs)) {
151     SmallVector<llvm::Value *> Operands;
152     Operands.reserve(LLVMI->getNumOperands());
153     for (auto [OpNum, Use] : enumerate(LLVMI->operands()))
154       Operands.push_back(Use.get());
155     InstrData.push_back({Operands, LLVMI});
156   }
157   assert(is_sorted(InstrData,
158                    [](const auto &D0, const auto &D1) {
159                      return D0.LLVMI->comesBefore(D1.LLVMI);
160                    }) &&
161          "Expected reverse program order!");
162   auto *BotLLVMI = cast<llvm::Instruction>(I->Val);
163   if (BotLLVMI->getNextNode() != nullptr)
164     NextLLVMIOrBB = BotLLVMI->getNextNode();
165   else
166     NextLLVMIOrBB = BotLLVMI->getParent();
167 }
168 
169 void EraseFromParent::accept() {
170   for (const auto &IData : InstrData)
171     IData.LLVMI->deleteValue();
172 }
173 
174 void EraseFromParent::revert(Tracker &Tracker) {
175   // Place the bottom-most instruction first.
176   auto [Operands, BotLLVMI] = InstrData[0];
177   if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(NextLLVMIOrBB)) {
178     BotLLVMI->insertBefore(NextLLVMI->getIterator());
179   } else {
180     auto *LLVMBB = cast<llvm::BasicBlock *>(NextLLVMIOrBB);
181     BotLLVMI->insertInto(LLVMBB, LLVMBB->end());
182   }
183   for (auto [OpNum, Op] : enumerate(Operands))
184     BotLLVMI->setOperand(OpNum, Op);
185 
186   // Go over the rest of the instructions and stack them on top.
187   for (auto [Operands, LLVMI] : drop_begin(InstrData)) {
188     LLVMI->insertBefore(BotLLVMI->getIterator());
189     for (auto [OpNum, Op] : enumerate(Operands))
190       LLVMI->setOperand(OpNum, Op);
191     BotLLVMI = LLVMI;
192   }
193   Tracker.getContext().registerValue(std::move(ErasedIPtr));
194 }
195 
196 #ifndef NDEBUG
197 void EraseFromParent::dump() const {
198   dump(dbgs());
199   dbgs() << "\n";
200 }
201 #endif // NDEBUG
202 
203 RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) {
204   if (auto *NextI = RemovedI->getNextNode())
205     NextInstrOrBB = NextI;
206   else
207     NextInstrOrBB = RemovedI->getParent();
208 }
209 
210 void RemoveFromParent::revert(Tracker &Tracker) {
211   if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) {
212     RemovedI->insertBefore(NextI);
213   } else {
214     auto *BB = cast<BasicBlock *>(NextInstrOrBB);
215     RemovedI->insertInto(BB, BB->end());
216   }
217 }
218 
219 #ifndef NDEBUG
220 void RemoveFromParent::dump() const {
221   dump(dbgs());
222   dbgs() << "\n";
223 }
224 #endif
225 
226 CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI)
227     : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {}
228 
229 void CatchSwitchAddHandler::revert(Tracker &Tracker) {
230   // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
231   // once it gets implemented.
232   auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val);
233   LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx);
234 }
235 
236 SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) {
237   for (const auto &C : Switch->cases())
238     Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()});
239 }
240 
241 void SwitchRemoveCase::revert(Tracker &Tracker) {
242   // SwitchInst::removeCase doesn't provide any guarantees about the order of
243   // cases after removal. In order to preserve the original ordering, we save
244   // all of them and, when reverting, clear them all then insert them in the
245   // desired order. This still relies on the fact that `addCase` will insert
246   // them at the end, but it is documented to invalidate `case_end()` so it's
247   // probably okay.
248   unsigned NumCases = Switch->getNumCases();
249   for (unsigned I = 0; I < NumCases; ++I)
250     Switch->removeCase(Switch->case_begin());
251   for (auto &Case : Cases)
252     Switch->addCase(Case.Val, Case.Dest);
253 }
254 
255 #ifndef NDEBUG
256 void SwitchRemoveCase::dump() const {
257   dump(dbgs());
258   dbgs() << "\n";
259 }
260 #endif // NDEBUG
261 
262 void SwitchAddCase::revert(Tracker &Tracker) {
263   auto It = Switch->findCaseValue(Val);
264   Switch->removeCase(It);
265 }
266 
267 #ifndef NDEBUG
268 void SwitchAddCase::dump() const {
269   dump(dbgs());
270   dbgs() << "\n";
271 }
272 #endif // NDEBUG
273 
274 MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
275   if (auto *NextI = MovedI->getNextNode())
276     NextInstrOrBB = NextI;
277   else
278     NextInstrOrBB = MovedI->getParent();
279 }
280 
281 void MoveInstr::revert(Tracker &Tracker) {
282   if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) {
283     MovedI->moveBefore(NextI);
284   } else {
285     auto *BB = cast<BasicBlock *>(NextInstrOrBB);
286     MovedI->moveBefore(*BB, BB->end());
287   }
288 }
289 
290 #ifndef NDEBUG
291 void MoveInstr::dump() const {
292   dump(dbgs());
293   dbgs() << "\n";
294 }
295 #endif
296 
297 void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); }
298 
299 InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {}
300 
301 #ifndef NDEBUG
302 void InsertIntoBB::dump() const {
303   dump(dbgs());
304   dbgs() << "\n";
305 }
306 #endif
307 
308 void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); }
309 
310 #ifndef NDEBUG
311 void CreateAndInsertInst::dump() const {
312   dump(dbgs());
313   dbgs() << "\n";
314 }
315 #endif
316 
317 ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI)
318     : SVI(SVI), PrevMask(SVI->getShuffleMask()) {}
319 
320 void ShuffleVectorSetMask::revert(Tracker &Tracker) {
321   SVI->setShuffleMask(PrevMask);
322 }
323 
324 #ifndef NDEBUG
325 void ShuffleVectorSetMask::dump() const {
326   dump(dbgs());
327   dbgs() << "\n";
328 }
329 #endif
330 
331 CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
332 
333 void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
334 #ifndef NDEBUG
335 void CmpSwapOperands::dump() const {
336   dump(dbgs());
337   dbgs() << "\n";
338 }
339 #endif
340 
341 void Tracker::save() {
342   State = TrackerState::Record;
343 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
344   SnapshotChecker.save();
345 #endif
346 }
347 
348 void Tracker::revert() {
349   assert(State == TrackerState::Record && "Forgot to save()!");
350   State = TrackerState::Disabled;
351   for (auto &Change : reverse(Changes))
352     Change->revert(*this);
353   Changes.clear();
354 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
355   SnapshotChecker.expectNoDiff();
356 #endif
357 }
358 
359 void Tracker::accept() {
360   assert(State == TrackerState::Record && "Forgot to save()!");
361   State = TrackerState::Disabled;
362   for (auto &Change : Changes)
363     Change->accept();
364   Changes.clear();
365 }
366 
367 #ifndef NDEBUG
368 void Tracker::dump(raw_ostream &OS) const {
369   for (auto [Idx, ChangePtr] : enumerate(Changes)) {
370     OS << Idx << ". ";
371     ChangePtr->dump(OS);
372     OS << "\n";
373   }
374 }
375 void Tracker::dump() const {
376   dump(dbgs());
377   dbgs() << "\n";
378 }
379 #endif // NDEBUG
380