xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision 1c4341d176492da5f276937b84a3d0c959e4cf5b)
//===- DependencyGraph.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8318d2f5eSvporpo 
9318d2f5eSvporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
100c9f7ef5Svporpo #include "llvm/ADT/ArrayRef.h"
11747d8f3fSvporpo #include "llvm/SandboxIR/Instruction.h"
1204a8bffdSvporpo #include "llvm/SandboxIR/Utils.h"
136e482148Svporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
14318d2f5eSvporpo 
1504a8bffdSvporpo namespace llvm::sandboxir {
16318d2f5eSvporpo 
17747d8f3fSvporpo PredIterator::value_type PredIterator::operator*() {
18747d8f3fSvporpo   // If it's a DGNode then we dereference the operand iterator.
19747d8f3fSvporpo   if (!isa<MemDGNode>(N)) {
20747d8f3fSvporpo     assert(OpIt != OpItE && "Can't dereference end iterator!");
21747d8f3fSvporpo     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
22747d8f3fSvporpo   }
23747d8f3fSvporpo   // It's a MemDGNode, so we check if we return either the use-def operand,
24747d8f3fSvporpo   // or a mem predecessor.
25747d8f3fSvporpo   if (OpIt != OpItE)
26747d8f3fSvporpo     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
27a4916d20Svporpo   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
28a4916d20Svporpo   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
29747d8f3fSvporpo          "Cant' dereference end iterator!");
30747d8f3fSvporpo   return *MemIt;
31747d8f3fSvporpo }
32747d8f3fSvporpo 
33747d8f3fSvporpo PredIterator &PredIterator::operator++() {
34747d8f3fSvporpo   // If it's a DGNode then we increment the use-def iterator.
35747d8f3fSvporpo   if (!isa<MemDGNode>(N)) {
36747d8f3fSvporpo     assert(OpIt != OpItE && "Already at end!");
37747d8f3fSvporpo     ++OpIt;
38747d8f3fSvporpo     // Skip operands that are not instructions.
39747d8f3fSvporpo     OpIt = skipNonInstr(OpIt, OpItE);
40747d8f3fSvporpo     return *this;
41747d8f3fSvporpo   }
42747d8f3fSvporpo   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
43747d8f3fSvporpo   // need to first increment that.
44747d8f3fSvporpo   if (OpIt != OpItE) {
45747d8f3fSvporpo     ++OpIt;
46747d8f3fSvporpo     // Skip operands that are not instructions.
47747d8f3fSvporpo     OpIt = skipNonInstr(OpIt, OpItE);
48747d8f3fSvporpo     return *this;
49747d8f3fSvporpo   }
50a4916d20Svporpo   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
51a4916d20Svporpo   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
52747d8f3fSvporpo   ++MemIt;
53747d8f3fSvporpo   return *this;
54747d8f3fSvporpo }
55747d8f3fSvporpo 
56747d8f3fSvporpo bool PredIterator::operator==(const PredIterator &Other) const {
57747d8f3fSvporpo   assert(DAG == Other.DAG && "Iterators of different DAGs!");
58747d8f3fSvporpo   assert(N == Other.N && "Iterators of different nodes!");
59747d8f3fSvporpo   return OpIt == Other.OpIt && MemIt == Other.MemIt;
60747d8f3fSvporpo }
61747d8f3fSvporpo 
626e482148Svporpo DGNode::~DGNode() {
636e482148Svporpo   if (SB == nullptr)
646e482148Svporpo     return;
656e482148Svporpo   SB->eraseFromBundle(this);
666e482148Svporpo }
676e482148Svporpo 
68318d2f5eSvporpo #ifndef NDEBUG
69fc08ad66Svporpo void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
701d09925bSvporpo   OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
71a4916d20Svporpo }
72fc08ad66Svporpo void DGNode::dump() const { print(dbgs()); }
73a4916d20Svporpo void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
74fc08ad66Svporpo   DGNode::print(OS, false);
75318d2f5eSvporpo   if (PrintDeps) {
76318d2f5eSvporpo     // Print memory preds.
77318d2f5eSvporpo     static constexpr const unsigned Indent = 4;
78fc08ad66Svporpo     for (auto *Pred : MemPreds)
79fc08ad66Svporpo       OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n";
80318d2f5eSvporpo   }
81318d2f5eSvporpo }
82318d2f5eSvporpo #endif // NDEBUG
83318d2f5eSvporpo 
8469c00679Svporpo MemDGNode *
8569c00679Svporpo MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
8669c00679Svporpo                                           const DependencyGraph &DAG) {
8769c00679Svporpo   Instruction *I = Intvl.top();
8869c00679Svporpo   Instruction *BeforeI = Intvl.bottom();
8969c00679Svporpo   // Walk down the chain looking for a mem-dep candidate instruction.
9069c00679Svporpo   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
9169c00679Svporpo     I = I->getNextNode();
9269c00679Svporpo   if (!DGNode::isMemDepNodeCandidate(I))
9369c00679Svporpo     return nullptr;
9469c00679Svporpo   return cast<MemDGNode>(DAG.getNode(I));
9569c00679Svporpo }
9669c00679Svporpo 
9769c00679Svporpo MemDGNode *
9869c00679Svporpo MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
9969c00679Svporpo                                           const DependencyGraph &DAG) {
10069c00679Svporpo   Instruction *I = Intvl.bottom();
10169c00679Svporpo   Instruction *AfterI = Intvl.top();
10269c00679Svporpo   // Walk up the chain looking for a mem-dep candidate instruction.
10369c00679Svporpo   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
10469c00679Svporpo     I = I->getPrevNode();
10569c00679Svporpo   if (!DGNode::isMemDepNodeCandidate(I))
10669c00679Svporpo     return nullptr;
10769c00679Svporpo   return cast<MemDGNode>(DAG.getNode(I));
10869c00679Svporpo }
10969c00679Svporpo 
110fd5e220fSvporpo Interval<MemDGNode>
111fd5e220fSvporpo MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
112fd5e220fSvporpo                                DependencyGraph &DAG) {
11369c00679Svporpo   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
114fd5e220fSvporpo   // If we couldn't find a mem node in range TopN - BotN then it's empty.
11569c00679Svporpo   if (TopMemN == nullptr)
116fd5e220fSvporpo     return {};
11769c00679Svporpo   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
11869c00679Svporpo   assert(BotMemN != nullptr && "TopMemN should be null too!");
119fd5e220fSvporpo   // Now that we have the mem-dep nodes, create and return the range.
12069c00679Svporpo   return Interval<MemDGNode>(TopMemN, BotMemN);
121fd5e220fSvporpo }
122fd5e220fSvporpo 
12304a8bffdSvporpo DependencyGraph::DependencyType
12404a8bffdSvporpo DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
12504a8bffdSvporpo   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
12604a8bffdSvporpo   if (FromI->mayWriteToMemory()) {
12704a8bffdSvporpo     if (ToI->mayReadFromMemory())
128267e8521SVasileios Porpodas       return DependencyType::ReadAfterWrite;
12904a8bffdSvporpo     if (ToI->mayWriteToMemory())
130267e8521SVasileios Porpodas       return DependencyType::WriteAfterWrite;
13104a8bffdSvporpo   } else if (FromI->mayReadFromMemory()) {
13204a8bffdSvporpo     if (ToI->mayWriteToMemory())
133267e8521SVasileios Porpodas       return DependencyType::WriteAfterRead;
13404a8bffdSvporpo   }
13504a8bffdSvporpo   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
136267e8521SVasileios Porpodas     return DependencyType::Control;
13704a8bffdSvporpo   if (ToI->isTerminator())
138267e8521SVasileios Porpodas     return DependencyType::Control;
13904a8bffdSvporpo   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
14004a8bffdSvporpo       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
141267e8521SVasileios Porpodas     return DependencyType::Other;
142267e8521SVasileios Porpodas   return DependencyType::None;
14304a8bffdSvporpo }
14404a8bffdSvporpo 
14504a8bffdSvporpo static bool isOrdered(Instruction *I) {
14604a8bffdSvporpo   auto IsOrdered = [](Instruction *I) {
14704a8bffdSvporpo     if (auto *LI = dyn_cast<LoadInst>(I))
14804a8bffdSvporpo       return !LI->isUnordered();
14904a8bffdSvporpo     if (auto *SI = dyn_cast<StoreInst>(I))
15004a8bffdSvporpo       return !SI->isUnordered();
15104a8bffdSvporpo     if (DGNode::isFenceLike(I))
15204a8bffdSvporpo       return true;
15304a8bffdSvporpo     return false;
15404a8bffdSvporpo   };
15504a8bffdSvporpo   bool Is = IsOrdered(I);
15604a8bffdSvporpo   assert((!Is || DGNode::isMemDepCandidate(I)) &&
15704a8bffdSvporpo          "An ordered instruction must be a MemDepCandidate!");
15804a8bffdSvporpo   return Is;
15904a8bffdSvporpo }
16004a8bffdSvporpo 
16104a8bffdSvporpo bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
16204a8bffdSvporpo                             DependencyType DepType) {
16304a8bffdSvporpo   std::optional<MemoryLocation> DstLocOpt =
16404a8bffdSvporpo       Utils::memoryLocationGetOrNone(DstI);
16504a8bffdSvporpo   if (!DstLocOpt)
16604a8bffdSvporpo     return true;
16704a8bffdSvporpo   // Check aliasing.
16804a8bffdSvporpo   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
16904a8bffdSvporpo          "Expected a mem instr");
17004a8bffdSvporpo   // TODO: Check AABudget
17104a8bffdSvporpo   ModRefInfo SrcModRef =
17204a8bffdSvporpo       isOrdered(SrcI)
173ee0e17a4Svporpo           ? ModRefInfo::ModRef
17404a8bffdSvporpo           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
17504a8bffdSvporpo   switch (DepType) {
176267e8521SVasileios Porpodas   case DependencyType::ReadAfterWrite:
177267e8521SVasileios Porpodas   case DependencyType::WriteAfterWrite:
17804a8bffdSvporpo     return isModSet(SrcModRef);
179267e8521SVasileios Porpodas   case DependencyType::WriteAfterRead:
18004a8bffdSvporpo     return isRefSet(SrcModRef);
18104a8bffdSvporpo   default:
18204a8bffdSvporpo     llvm_unreachable("Expected only RAW, WAW and WAR!");
18304a8bffdSvporpo   }
18404a8bffdSvporpo }
18504a8bffdSvporpo 
18604a8bffdSvporpo bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
18704a8bffdSvporpo   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
18804a8bffdSvporpo   switch (RoughDepType) {
189267e8521SVasileios Porpodas   case DependencyType::ReadAfterWrite:
190267e8521SVasileios Porpodas   case DependencyType::WriteAfterWrite:
191267e8521SVasileios Porpodas   case DependencyType::WriteAfterRead:
19204a8bffdSvporpo     return alias(SrcI, DstI, RoughDepType);
193267e8521SVasileios Porpodas   case DependencyType::Control:
19404a8bffdSvporpo     // Adding actual dep edges from PHIs/to terminator would just create too
19504a8bffdSvporpo     // many edges, which would be bad for compile-time.
19604a8bffdSvporpo     // So we ignore them in the DAG formation but handle them in the
19704a8bffdSvporpo     // scheduler, while sorting the ready list.
19804a8bffdSvporpo     return false;
199267e8521SVasileios Porpodas   case DependencyType::Other:
20004a8bffdSvporpo     return true;
201267e8521SVasileios Porpodas   case DependencyType::None:
20204a8bffdSvporpo     return false;
20304a8bffdSvporpo   }
20400c1c589SSimon Pilgrim   llvm_unreachable("Unknown DependencyType enum");
20504a8bffdSvporpo }
20604a8bffdSvporpo 
207a4916d20Svporpo void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
20804a8bffdSvporpo                                      const Interval<MemDGNode> &SrcScanRange) {
20904a8bffdSvporpo   assert(isa<MemDGNode>(DstN) &&
21004a8bffdSvporpo          "DstN is the mem dep destination, so it must be mem");
21104a8bffdSvporpo   Instruction *DstI = DstN.getInstruction();
21204a8bffdSvporpo   // Walk up the instruction chain from ScanRange bottom to top, looking for
21304a8bffdSvporpo   // memory instrs that may alias.
21404a8bffdSvporpo   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
21504a8bffdSvporpo     Instruction *SrcI = SrcN.getInstruction();
21604a8bffdSvporpo     if (hasDep(SrcI, DstI))
21704a8bffdSvporpo       DstN.addMemPred(&SrcN);
21804a8bffdSvporpo   }
21904a8bffdSvporpo }
22004a8bffdSvporpo 
221fc08ad66Svporpo void DependencyGraph::setDefUseUnscheduledSuccs(
222fc08ad66Svporpo     const Interval<Instruction> &NewInterval) {
223fc08ad66Svporpo   // +---+
224fc08ad66Svporpo   // |   |  Def
225fc08ad66Svporpo   // |   |   |
226fc08ad66Svporpo   // |   |   v
227fc08ad66Svporpo   // |   |  Use
228fc08ad66Svporpo   // +---+
229fc08ad66Svporpo   // Set the intra-interval counters in NewInterval.
230fc08ad66Svporpo   for (Instruction &I : NewInterval) {
231fc08ad66Svporpo     for (Value *Op : I.operands()) {
232fc08ad66Svporpo       auto *OpI = dyn_cast<Instruction>(Op);
233fc08ad66Svporpo       if (OpI == nullptr)
234fc08ad66Svporpo         continue;
235c7053ac2Svporpo       // TODO: For now don't cross BBs.
236c7053ac2Svporpo       if (OpI->getParent() != I.getParent())
237c7053ac2Svporpo         continue;
238fc08ad66Svporpo       if (!NewInterval.contains(OpI))
239fc08ad66Svporpo         continue;
240fc08ad66Svporpo       auto *OpN = getNode(OpI);
241fc08ad66Svporpo       if (OpN == nullptr)
242fc08ad66Svporpo         continue;
243fc08ad66Svporpo       ++OpN->UnscheduledSuccs;
244fc08ad66Svporpo     }
245fc08ad66Svporpo   }
246fc08ad66Svporpo 
247fc08ad66Svporpo   // Now handle the cross-interval edges.
248fc08ad66Svporpo   bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval);
249fc08ad66Svporpo   const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
250fc08ad66Svporpo   const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
251fc08ad66Svporpo   // +---+
252fc08ad66Svporpo   // |Top|
253fc08ad66Svporpo   // |   |  Def
254fc08ad66Svporpo   // +---+   |
255fc08ad66Svporpo   // |   |   v
256fc08ad66Svporpo   // |Bot|  Use
257fc08ad66Svporpo   // |   |
258fc08ad66Svporpo   // +---+
259fc08ad66Svporpo   // Walk over all instructions in "BotInterval" and update the counter
260fc08ad66Svporpo   // of operands that are in "TopInterval".
261fc08ad66Svporpo   for (Instruction &BotI : BotInterval) {
2621d09925bSvporpo     auto *BotN = getNode(&BotI);
2631d09925bSvporpo     // Skip scheduled nodes.
2641d09925bSvporpo     if (BotN->scheduled())
2651d09925bSvporpo       continue;
266fc08ad66Svporpo     for (Value *Op : BotI.operands()) {
267fc08ad66Svporpo       auto *OpI = dyn_cast<Instruction>(Op);
268fc08ad66Svporpo       if (OpI == nullptr)
269fc08ad66Svporpo         continue;
270fc08ad66Svporpo       auto *OpN = getNode(OpI);
271fc08ad66Svporpo       if (OpN == nullptr)
272fc08ad66Svporpo         continue;
273*1c4341d1SVasileios Porpodas       if (!TopInterval.contains(OpI))
274*1c4341d1SVasileios Porpodas         continue;
275fc08ad66Svporpo       ++OpN->UnscheduledSuccs;
276fc08ad66Svporpo     }
277fc08ad66Svporpo   }
278fc08ad66Svporpo }
279fc08ad66Svporpo 
280e8dd95e9Svporpo void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
281e8dd95e9Svporpo   // Create Nodes only for the new sections of the DAG.
282e8dd95e9Svporpo   DGNode *LastN = getOrCreateNode(NewInterval.top());
283fd5e220fSvporpo   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
284e8dd95e9Svporpo   for (Instruction &I : drop_begin(NewInterval)) {
28504a8bffdSvporpo     auto *N = getOrCreateNode(&I);
286fd5e220fSvporpo     // Build the Mem node chain.
287fd5e220fSvporpo     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
288fd5e220fSvporpo       MemN->setPrevNode(LastMemN);
289fd5e220fSvporpo       LastMemN = MemN;
290fd5e220fSvporpo     }
291318d2f5eSvporpo   }
292e8dd95e9Svporpo   // Link new MemDGNode chain with the old one, if any.
293e8dd95e9Svporpo   if (!DAGInterval.empty()) {
29431b85c6eSvporpo     bool NewIsAbove = NewInterval.comesBefore(DAGInterval);
295e8dd95e9Svporpo     const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
296e8dd95e9Svporpo     const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
297e8dd95e9Svporpo     MemDGNode *LinkTopN =
298e8dd95e9Svporpo         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
299e8dd95e9Svporpo     MemDGNode *LinkBotN =
300e8dd95e9Svporpo         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
3011d09925bSvporpo     assert((LinkTopN == nullptr || LinkBotN == nullptr ||
3021d09925bSvporpo             LinkTopN->comesBefore(LinkBotN)) &&
3031d09925bSvporpo            "Wrong order!");
304e8dd95e9Svporpo     if (LinkTopN != nullptr && LinkBotN != nullptr) {
305e8dd95e9Svporpo       LinkTopN->setNextNode(LinkBotN);
306e8dd95e9Svporpo     }
307e8dd95e9Svporpo #ifndef NDEBUG
308e8dd95e9Svporpo     // TODO: Remove this once we've done enough testing.
309e8dd95e9Svporpo     // Check that the chain is well formed.
310e8dd95e9Svporpo     auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
311e8dd95e9Svporpo     MemDGNode *ChainTopN =
312e8dd95e9Svporpo         MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
313e8dd95e9Svporpo     MemDGNode *ChainBotN =
314e8dd95e9Svporpo         MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
315e8dd95e9Svporpo     if (ChainTopN != nullptr && ChainBotN != nullptr) {
316e8dd95e9Svporpo       for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
317e8dd95e9Svporpo            LastN = N, N = N->getNextNode()) {
318e8dd95e9Svporpo         assert(N == LastN->getNextNode() && "Bad chain!");
319e8dd95e9Svporpo         assert(N->getPrevNode() == LastN && "Bad chain!");
320e8dd95e9Svporpo       }
321e8dd95e9Svporpo     }
322e8dd95e9Svporpo #endif // NDEBUG
323e8dd95e9Svporpo   }
324fc08ad66Svporpo 
325fc08ad66Svporpo   setDefUseUnscheduledSuccs(NewInterval);
326e8dd95e9Svporpo }
327e8dd95e9Svporpo 
328b41987beSvporpo MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
329b41987beSvporpo                                                MemDGNode *SkipN) const {
330eeb55d3aSvporpo   auto *I = N->getInstruction();
331eeb55d3aSvporpo   for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
332eeb55d3aSvporpo        PrevI = PrevI->getPrevNode()) {
333eeb55d3aSvporpo     auto *PrevN = getNodeOrNull(PrevI);
334eeb55d3aSvporpo     if (PrevN == nullptr)
335eeb55d3aSvporpo       return nullptr;
336b41987beSvporpo     auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
337b41987beSvporpo     if (PrevMemN != nullptr && PrevMemN != SkipN)
338eeb55d3aSvporpo       return PrevMemN;
339eeb55d3aSvporpo   }
340eeb55d3aSvporpo   return nullptr;
341eeb55d3aSvporpo }
342eeb55d3aSvporpo 
343b41987beSvporpo MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
344b41987beSvporpo                                               MemDGNode *SkipN) const {
345eeb55d3aSvporpo   auto *I = N->getInstruction();
346eeb55d3aSvporpo   for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
347eeb55d3aSvporpo        NextI = NextI->getNextNode()) {
348eeb55d3aSvporpo     auto *NextN = getNodeOrNull(NextI);
349eeb55d3aSvporpo     if (NextN == nullptr)
350eeb55d3aSvporpo       return nullptr;
351b41987beSvporpo     auto *NextMemN = dyn_cast<MemDGNode>(NextN);
352b41987beSvporpo     if (NextMemN != nullptr && NextMemN != SkipN)
353eeb55d3aSvporpo       return NextMemN;
354eeb55d3aSvporpo   }
355eeb55d3aSvporpo   return nullptr;
356eeb55d3aSvporpo }
357eeb55d3aSvporpo 
358eeb55d3aSvporpo void DependencyGraph::notifyCreateInstr(Instruction *I) {
359eeb55d3aSvporpo   auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
360eeb55d3aSvporpo   // TODO: Update the dependencies for the new node.
361eeb55d3aSvporpo 
362eeb55d3aSvporpo   // Update the MemDGNode chain if this is a memory node.
363eeb55d3aSvporpo   if (MemN != nullptr) {
364eeb55d3aSvporpo     if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) {
365eeb55d3aSvporpo       PrevMemN->NextMemN = MemN;
366eeb55d3aSvporpo       MemN->PrevMemN = PrevMemN;
367eeb55d3aSvporpo     }
368eeb55d3aSvporpo     if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) {
369eeb55d3aSvporpo       NextMemN->PrevMemN = MemN;
370eeb55d3aSvporpo       MemN->NextMemN = NextMemN;
371eeb55d3aSvporpo     }
372eeb55d3aSvporpo   }
373eeb55d3aSvporpo }
374eeb55d3aSvporpo 
3757a38445eSvporpo void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
37640893149SVasileios Porpodas   // NOTE: This function runs before `I` moves to its new destination.
3777a38445eSvporpo   BasicBlock *BB = To.getNodeParent();
37840893149SVasileios Porpodas   assert(!(To != BB->end() && &*To == I->getNextNode()) &&
37940893149SVasileios Porpodas          !(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
38040893149SVasileios Porpodas          "Should not have been called if destination is same as origin.");
3817a38445eSvporpo 
382b41987beSvporpo   // TODO: We can only handle fully internal movements within DAGInterval or at
383b41987beSvporpo   // the borders, i.e., right before the top or right after the bottom.
384b41987beSvporpo   assert(To.getNodeParent() == I->getParent() &&
385b41987beSvporpo          "TODO: We don't support movement across BBs!");
386b41987beSvporpo   assert(
387b41987beSvporpo       (To == std::next(DAGInterval.bottom()->getIterator()) ||
388b41987beSvporpo        (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
389b41987beSvporpo        (To != BB->end() && DAGInterval.contains(&*To))) &&
390b41987beSvporpo       "TODO: To should be either within the DAGInterval or right "
391b41987beSvporpo       "before/after it.");
392b41987beSvporpo 
393b41987beSvporpo   // Make a copy of the DAGInterval before we update it.
394b41987beSvporpo   auto OrigDAGInterval = DAGInterval;
395b41987beSvporpo 
3967a38445eSvporpo   // Maintain the DAGInterval.
3977a38445eSvporpo   DAGInterval.notifyMoveInstr(I, To);
3987a38445eSvporpo 
3997a38445eSvporpo   // TODO: Perhaps check if this is legal by checking the dependencies?
4007a38445eSvporpo 
4017a38445eSvporpo   // Update the MemDGNode chain to reflect the instr movement if necessary.
4027a38445eSvporpo   DGNode *N = getNodeOrNull(I);
4037a38445eSvporpo   if (N == nullptr)
4047a38445eSvporpo     return;
4057a38445eSvporpo   MemDGNode *MemN = dyn_cast<MemDGNode>(N);
4067a38445eSvporpo   if (MemN == nullptr)
4077a38445eSvporpo     return;
408b41987beSvporpo 
409b41987beSvporpo   // First safely detach it from the existing chain.
4107a38445eSvporpo   MemN->detachFromChain();
411b41987beSvporpo 
4127a38445eSvporpo   // Now insert it back into the chain at the new location.
413b41987beSvporpo   //
414b41987beSvporpo   // We won't always have a DGNode to insert before it. If `To` is BB->end() or
415b41987beSvporpo   // if it points to an instr after DAGInterval.bottom() then we will have to
416b41987beSvporpo   // find a node to insert *after*.
417b41987beSvporpo   //
418b41987beSvporpo   // BB:                              BB:
419b41987beSvporpo   //  I1                               I1 ^
420b41987beSvporpo   //  I2                               I2 | DAGInteval [I1 to I3]
421b41987beSvporpo   //  I3                               I3 V
422b41987beSvporpo   //  I4                               I4   <- `To` == right after DAGInterval
423b41987beSvporpo   //    <- `To` == BB->end()
424b41987beSvporpo   //
425b41987beSvporpo   if (To == BB->end() ||
426b41987beSvporpo       To == std::next(OrigDAGInterval.bottom()->getIterator())) {
427b41987beSvporpo     // If we don't have a node to insert before, find a node to insert after and
428b41987beSvporpo     // update the chain.
429b41987beSvporpo     DGNode *InsertAfterN = getNode(&*std::prev(To));
430b41987beSvporpo     MemN->setPrevNode(
431b41987beSvporpo         getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
4327a38445eSvporpo   } else {
433b41987beSvporpo     // We have a node to insert before, so update the chain.
434b41987beSvporpo     DGNode *BeforeToN = getNode(&*To);
435b41987beSvporpo     MemN->setPrevNode(
436b41987beSvporpo         getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
437b41987beSvporpo     MemN->setNextNode(
438b41987beSvporpo         getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
4397a38445eSvporpo   }
4407a38445eSvporpo }
4417a38445eSvporpo 
442cafb6b99Svporpo void DependencyGraph::notifyEraseInstr(Instruction *I) {
443cafb6b99Svporpo   // Update the MemDGNode chain if this is a memory node.
444cafb6b99Svporpo   if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) {
445cafb6b99Svporpo     auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false);
446cafb6b99Svporpo     auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false);
447cafb6b99Svporpo     if (PrevMemN != nullptr)
448cafb6b99Svporpo       PrevMemN->NextMemN = NextMemN;
449cafb6b99Svporpo     if (NextMemN != nullptr)
450cafb6b99Svporpo       NextMemN->PrevMemN = PrevMemN;
451cafb6b99Svporpo   }
452cafb6b99Svporpo 
453cafb6b99Svporpo   InstrToNodeMap.erase(I);
454cafb6b99Svporpo 
455cafb6b99Svporpo   // TODO: Update the dependencies.
456cafb6b99Svporpo }
457cafb6b99Svporpo 
458e8dd95e9Svporpo Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
459e8dd95e9Svporpo   if (Instrs.empty())
460e8dd95e9Svporpo     return {};
461e8dd95e9Svporpo 
462e8dd95e9Svporpo   Interval<Instruction> InstrsInterval(Instrs);
463e8dd95e9Svporpo   Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
464e8dd95e9Svporpo   auto NewInterval = Union.getSingleDiff(DAGInterval);
465e8dd95e9Svporpo   if (NewInterval.empty())
466e8dd95e9Svporpo     return {};
467e8dd95e9Svporpo 
468e8dd95e9Svporpo   createNewNodes(NewInterval);
469e8dd95e9Svporpo 
47004a8bffdSvporpo   // Create the dependencies.
471e8dd95e9Svporpo   //
47208bfc9b0Svporpo   // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval.
47308bfc9b0Svporpo   // +---+       -             -
47408bfc9b0Svporpo   // |   | SrcN  |             |
47508bfc9b0Svporpo   // |   |  |    | SrcRange    |
47608bfc9b0Svporpo   // |New|  v    |             | DstRange
47708bfc9b0Svporpo   // |   | DstN  -             |
47808bfc9b0Svporpo   // |   |                     |
47908bfc9b0Svporpo   // +---+                     -
48008bfc9b0Svporpo   // We are scanning for deps with destination in NewInterval and sources in
48108bfc9b0Svporpo   // NewInterval until DstN, for each DstN.
48208bfc9b0Svporpo   auto FullScan = [this](const Interval<Instruction> Intvl) {
48308bfc9b0Svporpo     auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this);
484ee0e17a4Svporpo     if (!DstRange.empty()) {
48504a8bffdSvporpo       for (MemDGNode &DstN : drop_begin(DstRange)) {
48604a8bffdSvporpo         auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
48704a8bffdSvporpo         scanAndAddDeps(DstN, SrcRange);
48804a8bffdSvporpo       }
489ee0e17a4Svporpo     }
49008bfc9b0Svporpo   };
49108bfc9b0Svporpo   if (DAGInterval.empty()) {
49208bfc9b0Svporpo     assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
49308bfc9b0Svporpo     FullScan(NewInterval);
494e8dd95e9Svporpo   }
495e8dd95e9Svporpo   // 2. The new section is below the old section.
49608bfc9b0Svporpo   // +---+       -
49708bfc9b0Svporpo   // |   |       |
49808bfc9b0Svporpo   // |Old| SrcN  |
49908bfc9b0Svporpo   // |   |  |    |
50008bfc9b0Svporpo   // +---+  |    | SrcRange
50108bfc9b0Svporpo   // +---+  |    |             -
50208bfc9b0Svporpo   // |   |  |    |             |
50308bfc9b0Svporpo   // |New|  v    |             | DstRange
50408bfc9b0Svporpo   // |   | DstN  -             |
50508bfc9b0Svporpo   // |   |                     |
50608bfc9b0Svporpo   // +---+                     -
50708bfc9b0Svporpo   // We are scanning for deps with destination in NewInterval because the deps
50808bfc9b0Svporpo   // in DAGInterval have already been computed. We consider sources in the whole
50908bfc9b0Svporpo   // range including both NewInterval and DAGInterval until DstN, for each DstN.
510e8dd95e9Svporpo   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
511e8dd95e9Svporpo     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
512e8dd95e9Svporpo     auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
513e8dd95e9Svporpo         DAGInterval.getUnionInterval(NewInterval), *this);
514e8dd95e9Svporpo     for (MemDGNode &DstN : DstRange) {
515e8dd95e9Svporpo       auto SrcRange =
516e8dd95e9Svporpo           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
517e8dd95e9Svporpo       scanAndAddDeps(DstN, SrcRange);
518e8dd95e9Svporpo     }
519e8dd95e9Svporpo   }
520e8dd95e9Svporpo   // 3. The new section is above the old section.
521e8dd95e9Svporpo   else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
52208bfc9b0Svporpo     // +---+       -             -
52308bfc9b0Svporpo     // |   | SrcN  |             |
52408bfc9b0Svporpo     // |New|  |    | SrcRange    | DstRange
52508bfc9b0Svporpo     // |   |  v    |             |
52608bfc9b0Svporpo     // |   | DstN  -             |
52708bfc9b0Svporpo     // |   |                     |
52808bfc9b0Svporpo     // +---+                     -
52908bfc9b0Svporpo     // +---+
53008bfc9b0Svporpo     // |Old|
53108bfc9b0Svporpo     // |   |
53208bfc9b0Svporpo     // +---+
53308bfc9b0Svporpo     // When scanning for deps with destination in NewInterval we need to fully
53408bfc9b0Svporpo     // scan the interval. This is the same as the scanning for a new DAG.
53508bfc9b0Svporpo     FullScan(NewInterval);
53608bfc9b0Svporpo 
53708bfc9b0Svporpo     // +---+       -
53808bfc9b0Svporpo     // |   |       |
53908bfc9b0Svporpo     // |New| SrcN  | SrcRange
54008bfc9b0Svporpo     // |   |  |    |
54108bfc9b0Svporpo     // |   |  |    |
54208bfc9b0Svporpo     // |   |  |    |
54308bfc9b0Svporpo     // +---+  |    -
54408bfc9b0Svporpo     // +---+  |                  -
54508bfc9b0Svporpo     // |Old|  v                  | DstRange
54608bfc9b0Svporpo     // |   | DstN                |
54708bfc9b0Svporpo     // +---+                     -
54808bfc9b0Svporpo     // When scanning for deps with destination in DAGInterval we need to
54908bfc9b0Svporpo     // consider sources from the NewInterval only, because all intra-DAGInterval
55008bfc9b0Svporpo     // dependencies have already been created.
55108bfc9b0Svporpo     auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
55208bfc9b0Svporpo     auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
55308bfc9b0Svporpo     for (MemDGNode &DstN : DstRangeOld)
554e8dd95e9Svporpo       scanAndAddDeps(DstN, SrcRange);
555e8dd95e9Svporpo   } else {
556e8dd95e9Svporpo     llvm_unreachable("We don't expect extending in both directions!");
557e8dd95e9Svporpo   }
55804a8bffdSvporpo 
559e8dd95e9Svporpo   DAGInterval = Union;
560e8dd95e9Svporpo   return NewInterval;
561318d2f5eSvporpo }
562318d2f5eSvporpo 
563318d2f5eSvporpo #ifndef NDEBUG
564318d2f5eSvporpo void DependencyGraph::print(raw_ostream &OS) const {
565318d2f5eSvporpo   // InstrToNodeMap is unordered so we need to create an ordered vector.
566318d2f5eSvporpo   SmallVector<DGNode *> Nodes;
567318d2f5eSvporpo   Nodes.reserve(InstrToNodeMap.size());
568318d2f5eSvporpo   for (const auto &Pair : InstrToNodeMap)
569318d2f5eSvporpo     Nodes.push_back(Pair.second.get());
570318d2f5eSvporpo   // Sort them based on which one comes first in the BB.
571318d2f5eSvporpo   sort(Nodes, [](DGNode *N1, DGNode *N2) {
572318d2f5eSvporpo     return N1->getInstruction()->comesBefore(N2->getInstruction());
573318d2f5eSvporpo   });
574318d2f5eSvporpo   for (auto *N : Nodes)
575318d2f5eSvporpo     N->print(OS, /*PrintDeps=*/true);
576318d2f5eSvporpo }
577318d2f5eSvporpo 
578318d2f5eSvporpo void DependencyGraph::dump() const {
579318d2f5eSvporpo   print(dbgs());
580318d2f5eSvporpo   dbgs() << "\n";
581318d2f5eSvporpo }
582318d2f5eSvporpo #endif // NDEBUG
58304a8bffdSvporpo 
58404a8bffdSvporpo } // namespace llvm::sandboxir
585