xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision 1d09925b4a6fd4af0120825132be23be12fb03d6)
1 //===- DependencyGraph.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Utils.h"
13 
14 namespace llvm::sandboxir {
15 
16 PredIterator::value_type PredIterator::operator*() {
17   // If it's a DGNode then we dereference the operand iterator.
18   if (!isa<MemDGNode>(N)) {
19     assert(OpIt != OpItE && "Can't dereference end iterator!");
20     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
21   }
22   // It's a MemDGNode, so we check if we return either the use-def operand,
23   // or a mem predecessor.
24   if (OpIt != OpItE)
25     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
27   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
28          "Cant' dereference end iterator!");
29   return *MemIt;
30 }
31 
32 PredIterator &PredIterator::operator++() {
33   // If it's a DGNode then we increment the use-def iterator.
34   if (!isa<MemDGNode>(N)) {
35     assert(OpIt != OpItE && "Already at end!");
36     ++OpIt;
37     // Skip operands that are not instructions.
38     OpIt = skipNonInstr(OpIt, OpItE);
39     return *this;
40   }
41   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
42   // need to first increment that.
43   if (OpIt != OpItE) {
44     ++OpIt;
45     // Skip operands that are not instructions.
46     OpIt = skipNonInstr(OpIt, OpItE);
47     return *this;
48   }
49   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
50   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
51   ++MemIt;
52   return *this;
53 }
54 
55 bool PredIterator::operator==(const PredIterator &Other) const {
56   assert(DAG == Other.DAG && "Iterators of different DAGs!");
57   assert(N == Other.N && "Iterators of different nodes!");
58   return OpIt == Other.OpIt && MemIt == Other.MemIt;
59 }
60 
61 #ifndef NDEBUG
62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
63   OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
64 }
65 void DGNode::dump() const { print(dbgs()); }
66 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
67   DGNode::print(OS, false);
68   if (PrintDeps) {
69     // Print memory preds.
70     static constexpr const unsigned Indent = 4;
71     for (auto *Pred : MemPreds)
72       OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n";
73   }
74 }
75 #endif // NDEBUG
76 
77 MemDGNode *
78 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
79                                           const DependencyGraph &DAG) {
80   Instruction *I = Intvl.top();
81   Instruction *BeforeI = Intvl.bottom();
82   // Walk down the chain looking for a mem-dep candidate instruction.
83   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
84     I = I->getNextNode();
85   if (!DGNode::isMemDepNodeCandidate(I))
86     return nullptr;
87   return cast<MemDGNode>(DAG.getNode(I));
88 }
89 
90 MemDGNode *
91 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
92                                           const DependencyGraph &DAG) {
93   Instruction *I = Intvl.bottom();
94   Instruction *AfterI = Intvl.top();
95   // Walk up the chain looking for a mem-dep candidate instruction.
96   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
97     I = I->getPrevNode();
98   if (!DGNode::isMemDepNodeCandidate(I))
99     return nullptr;
100   return cast<MemDGNode>(DAG.getNode(I));
101 }
102 
103 Interval<MemDGNode>
104 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
105                                DependencyGraph &DAG) {
106   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
107   // If we couldn't find a mem node in range TopN - BotN then it's empty.
108   if (TopMemN == nullptr)
109     return {};
110   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
111   assert(BotMemN != nullptr && "TopMemN should be null too!");
112   // Now that we have the mem-dep nodes, create and return the range.
113   return Interval<MemDGNode>(TopMemN, BotMemN);
114 }
115 
116 DependencyGraph::DependencyType
117 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
118   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
119   if (FromI->mayWriteToMemory()) {
120     if (ToI->mayReadFromMemory())
121       return DependencyType::ReadAfterWrite;
122     if (ToI->mayWriteToMemory())
123       return DependencyType::WriteAfterWrite;
124   } else if (FromI->mayReadFromMemory()) {
125     if (ToI->mayWriteToMemory())
126       return DependencyType::WriteAfterRead;
127   }
128   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
129     return DependencyType::Control;
130   if (ToI->isTerminator())
131     return DependencyType::Control;
132   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
133       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
134     return DependencyType::Other;
135   return DependencyType::None;
136 }
137 
138 static bool isOrdered(Instruction *I) {
139   auto IsOrdered = [](Instruction *I) {
140     if (auto *LI = dyn_cast<LoadInst>(I))
141       return !LI->isUnordered();
142     if (auto *SI = dyn_cast<StoreInst>(I))
143       return !SI->isUnordered();
144     if (DGNode::isFenceLike(I))
145       return true;
146     return false;
147   };
148   bool Is = IsOrdered(I);
149   assert((!Is || DGNode::isMemDepCandidate(I)) &&
150          "An ordered instruction must be a MemDepCandidate!");
151   return Is;
152 }
153 
154 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
155                             DependencyType DepType) {
156   std::optional<MemoryLocation> DstLocOpt =
157       Utils::memoryLocationGetOrNone(DstI);
158   if (!DstLocOpt)
159     return true;
160   // Check aliasing.
161   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
162          "Expected a mem instr");
163   // TODO: Check AABudget
164   ModRefInfo SrcModRef =
165       isOrdered(SrcI)
166           ? ModRefInfo::ModRef
167           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
168   switch (DepType) {
169   case DependencyType::ReadAfterWrite:
170   case DependencyType::WriteAfterWrite:
171     return isModSet(SrcModRef);
172   case DependencyType::WriteAfterRead:
173     return isRefSet(SrcModRef);
174   default:
175     llvm_unreachable("Expected only RAW, WAW and WAR!");
176   }
177 }
178 
179 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
180   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
181   switch (RoughDepType) {
182   case DependencyType::ReadAfterWrite:
183   case DependencyType::WriteAfterWrite:
184   case DependencyType::WriteAfterRead:
185     return alias(SrcI, DstI, RoughDepType);
186   case DependencyType::Control:
187     // Adding actual dep edges from PHIs/to terminator would just create too
188     // many edges, which would be bad for compile-time.
189     // So we ignore them in the DAG formation but handle them in the
190     // scheduler, while sorting the ready list.
191     return false;
192   case DependencyType::Other:
193     return true;
194   case DependencyType::None:
195     return false;
196   }
197   llvm_unreachable("Unknown DependencyType enum");
198 }
199 
200 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
201                                      const Interval<MemDGNode> &SrcScanRange) {
202   assert(isa<MemDGNode>(DstN) &&
203          "DstN is the mem dep destination, so it must be mem");
204   Instruction *DstI = DstN.getInstruction();
205   // Walk up the instruction chain from ScanRange bottom to top, looking for
206   // memory instrs that may alias.
207   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
208     Instruction *SrcI = SrcN.getInstruction();
209     if (hasDep(SrcI, DstI))
210       DstN.addMemPred(&SrcN);
211   }
212 }
213 
214 void DependencyGraph::setDefUseUnscheduledSuccs(
215     const Interval<Instruction> &NewInterval) {
216   // +---+
217   // |   |  Def
218   // |   |   |
219   // |   |   v
220   // |   |  Use
221   // +---+
222   // Set the intra-interval counters in NewInterval.
223   for (Instruction &I : NewInterval) {
224     for (Value *Op : I.operands()) {
225       auto *OpI = dyn_cast<Instruction>(Op);
226       if (OpI == nullptr)
227         continue;
228       if (!NewInterval.contains(OpI))
229         continue;
230       auto *OpN = getNode(OpI);
231       if (OpN == nullptr)
232         continue;
233       ++OpN->UnscheduledSuccs;
234     }
235   }
236 
237   // Now handle the cross-interval edges.
238   bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval);
239   const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
240   const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
241   // +---+
242   // |Top|
243   // |   |  Def
244   // +---+   |
245   // |   |   v
246   // |Bot|  Use
247   // |   |
248   // +---+
249   // Walk over all instructions in "BotInterval" and update the counter
250   // of operands that are in "TopInterval".
251   for (Instruction &BotI : BotInterval) {
252     auto *BotN = getNode(&BotI);
253     // Skip scheduled nodes.
254     if (BotN->scheduled())
255       continue;
256     for (Value *Op : BotI.operands()) {
257       auto *OpI = dyn_cast<Instruction>(Op);
258       if (OpI == nullptr)
259         continue;
260       if (!TopInterval.contains(OpI))
261         continue;
262       auto *OpN = getNode(OpI);
263       if (OpN == nullptr)
264         continue;
265       ++OpN->UnscheduledSuccs;
266     }
267   }
268 }
269 
270 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
271   // Create Nodes only for the new sections of the DAG.
272   DGNode *LastN = getOrCreateNode(NewInterval.top());
273   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
274   for (Instruction &I : drop_begin(NewInterval)) {
275     auto *N = getOrCreateNode(&I);
276     // Build the Mem node chain.
277     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
278       MemN->setPrevNode(LastMemN);
279       if (LastMemN != nullptr)
280         LastMemN->setNextNode(MemN);
281       LastMemN = MemN;
282     }
283   }
284   // Link new MemDGNode chain with the old one, if any.
285   if (!DAGInterval.empty()) {
286     bool NewIsAbove = NewInterval.comesBefore(DAGInterval);
287     const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
288     const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
289     MemDGNode *LinkTopN =
290         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
291     MemDGNode *LinkBotN =
292         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
293     assert((LinkTopN == nullptr || LinkBotN == nullptr ||
294             LinkTopN->comesBefore(LinkBotN)) &&
295            "Wrong order!");
296     if (LinkTopN != nullptr && LinkBotN != nullptr) {
297       LinkTopN->setNextNode(LinkBotN);
298       LinkBotN->setPrevNode(LinkTopN);
299     }
300 #ifndef NDEBUG
301     // TODO: Remove this once we've done enough testing.
302     // Check that the chain is well formed.
303     auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
304     MemDGNode *ChainTopN =
305         MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
306     MemDGNode *ChainBotN =
307         MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
308     if (ChainTopN != nullptr && ChainBotN != nullptr) {
309       for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
310            LastN = N, N = N->getNextNode()) {
311         assert(N == LastN->getNextNode() && "Bad chain!");
312         assert(N->getPrevNode() == LastN && "Bad chain!");
313       }
314     }
315 #endif // NDEBUG
316   }
317 
318   setDefUseUnscheduledSuccs(NewInterval);
319 }
320 
321 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
322   if (Instrs.empty())
323     return {};
324 
325   Interval<Instruction> InstrsInterval(Instrs);
326   Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
327   auto NewInterval = Union.getSingleDiff(DAGInterval);
328   if (NewInterval.empty())
329     return {};
330 
331   createNewNodes(NewInterval);
332 
333   // Create the dependencies.
334   //
335   // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval.
336   // +---+       -             -
337   // |   | SrcN  |             |
338   // |   |  |    | SrcRange    |
339   // |New|  v    |             | DstRange
340   // |   | DstN  -             |
341   // |   |                     |
342   // +---+                     -
343   // We are scanning for deps with destination in NewInterval and sources in
344   // NewInterval until DstN, for each DstN.
345   auto FullScan = [this](const Interval<Instruction> Intvl) {
346     auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this);
347     if (!DstRange.empty()) {
348       for (MemDGNode &DstN : drop_begin(DstRange)) {
349         auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
350         scanAndAddDeps(DstN, SrcRange);
351       }
352     }
353   };
354   if (DAGInterval.empty()) {
355     assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
356     FullScan(NewInterval);
357   }
358   // 2. The new section is below the old section.
359   // +---+       -
360   // |   |       |
361   // |Old| SrcN  |
362   // |   |  |    |
363   // +---+  |    | SrcRange
364   // +---+  |    |             -
365   // |   |  |    |             |
366   // |New|  v    |             | DstRange
367   // |   | DstN  -             |
368   // |   |                     |
369   // +---+                     -
370   // We are scanning for deps with destination in NewInterval because the deps
371   // in DAGInterval have already been computed. We consider sources in the whole
372   // range including both NewInterval and DAGInterval until DstN, for each DstN.
373   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
374     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
375     auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
376         DAGInterval.getUnionInterval(NewInterval), *this);
377     for (MemDGNode &DstN : DstRange) {
378       auto SrcRange =
379           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
380       scanAndAddDeps(DstN, SrcRange);
381     }
382   }
383   // 3. The new section is above the old section.
384   else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
385     // +---+       -             -
386     // |   | SrcN  |             |
387     // |New|  |    | SrcRange    | DstRange
388     // |   |  v    |             |
389     // |   | DstN  -             |
390     // |   |                     |
391     // +---+                     -
392     // +---+
393     // |Old|
394     // |   |
395     // +---+
396     // When scanning for deps with destination in NewInterval we need to fully
397     // scan the interval. This is the same as the scanning for a new DAG.
398     FullScan(NewInterval);
399 
400     // +---+       -
401     // |   |       |
402     // |New| SrcN  | SrcRange
403     // |   |  |    |
404     // |   |  |    |
405     // |   |  |    |
406     // +---+  |    -
407     // +---+  |                  -
408     // |Old|  v                  | DstRange
409     // |   | DstN                |
410     // +---+                     -
411     // When scanning for deps with destination in DAGInterval we need to
412     // consider sources from the NewInterval only, because all intra-DAGInterval
413     // dependencies have already been created.
414     auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
415     auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
416     for (MemDGNode &DstN : DstRangeOld)
417       scanAndAddDeps(DstN, SrcRange);
418   } else {
419     llvm_unreachable("We don't expect extending in both directions!");
420   }
421 
422   DAGInterval = Union;
423   return NewInterval;
424 }
425 
426 #ifndef NDEBUG
427 void DependencyGraph::print(raw_ostream &OS) const {
428   // InstrToNodeMap is unordered so we need to create an ordered vector.
429   SmallVector<DGNode *> Nodes;
430   Nodes.reserve(InstrToNodeMap.size());
431   for (const auto &Pair : InstrToNodeMap)
432     Nodes.push_back(Pair.second.get());
433   // Sort them based on which one comes first in the BB.
434   sort(Nodes, [](DGNode *N1, DGNode *N2) {
435     return N1->getInstruction()->comesBefore(N2->getInstruction());
436   });
437   for (auto *N : Nodes)
438     N->print(OS, /*PrintDeps=*/true);
439 }
440 
441 void DependencyGraph::dump() const {
442   print(dbgs());
443   dbgs() << "\n";
444 }
445 #endif // NDEBUG
446 
447 } // namespace llvm::sandboxir
448