xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision 08bfc9b0aeee798052465246d8f7eb01a0eea2db)
1 //===- DependencyGraph.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Utils.h"
13 
14 namespace llvm::sandboxir {
15 
16 PredIterator::value_type PredIterator::operator*() {
17   // If it's a DGNode then we dereference the operand iterator.
18   if (!isa<MemDGNode>(N)) {
19     assert(OpIt != OpItE && "Can't dereference end iterator!");
20     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
21   }
22   // It's a MemDGNode, so we check if we return either the use-def operand,
23   // or a mem predecessor.
24   if (OpIt != OpItE)
25     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
27   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
28          "Cant' dereference end iterator!");
29   return *MemIt;
30 }
31 
32 PredIterator &PredIterator::operator++() {
33   // If it's a DGNode then we increment the use-def iterator.
34   if (!isa<MemDGNode>(N)) {
35     assert(OpIt != OpItE && "Already at end!");
36     ++OpIt;
37     // Skip operands that are not instructions.
38     OpIt = skipNonInstr(OpIt, OpItE);
39     return *this;
40   }
41   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
42   // need to first increment that.
43   if (OpIt != OpItE) {
44     ++OpIt;
45     // Skip operands that are not instructions.
46     OpIt = skipNonInstr(OpIt, OpItE);
47     return *this;
48   }
49   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
50   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
51   ++MemIt;
52   return *this;
53 }
54 
55 bool PredIterator::operator==(const PredIterator &Other) const {
56   assert(DAG == Other.DAG && "Iterators of different DAGs!");
57   assert(N == Other.N && "Iterators of different nodes!");
58   return OpIt == Other.OpIt && MemIt == Other.MemIt;
59 }
60 
61 #ifndef NDEBUG
62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); }
63 void DGNode::dump() const {
64   print(dbgs());
65   dbgs() << "\n";
66 }
67 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
68   I->dumpOS(OS);
69   if (PrintDeps) {
70     // Print memory preds.
71     static constexpr const unsigned Indent = 4;
72     for (auto *Pred : MemPreds) {
73       OS.indent(Indent) << "<-";
74       Pred->print(OS, false);
75       OS << "\n";
76     }
77   }
78 }
79 #endif // NDEBUG
80 
81 MemDGNode *
82 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
83                                           const DependencyGraph &DAG) {
84   Instruction *I = Intvl.top();
85   Instruction *BeforeI = Intvl.bottom();
86   // Walk down the chain looking for a mem-dep candidate instruction.
87   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
88     I = I->getNextNode();
89   if (!DGNode::isMemDepNodeCandidate(I))
90     return nullptr;
91   return cast<MemDGNode>(DAG.getNode(I));
92 }
93 
94 MemDGNode *
95 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
96                                           const DependencyGraph &DAG) {
97   Instruction *I = Intvl.bottom();
98   Instruction *AfterI = Intvl.top();
99   // Walk up the chain looking for a mem-dep candidate instruction.
100   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
101     I = I->getPrevNode();
102   if (!DGNode::isMemDepNodeCandidate(I))
103     return nullptr;
104   return cast<MemDGNode>(DAG.getNode(I));
105 }
106 
107 Interval<MemDGNode>
108 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
109                                DependencyGraph &DAG) {
110   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
111   // If we couldn't find a mem node in range TopN - BotN then it's empty.
112   if (TopMemN == nullptr)
113     return {};
114   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
115   assert(BotMemN != nullptr && "TopMemN should be null too!");
116   // Now that we have the mem-dep nodes, create and return the range.
117   return Interval<MemDGNode>(TopMemN, BotMemN);
118 }
119 
120 DependencyGraph::DependencyType
121 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
122   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
123   if (FromI->mayWriteToMemory()) {
124     if (ToI->mayReadFromMemory())
125       return DependencyType::ReadAfterWrite;
126     if (ToI->mayWriteToMemory())
127       return DependencyType::WriteAfterWrite;
128   } else if (FromI->mayReadFromMemory()) {
129     if (ToI->mayWriteToMemory())
130       return DependencyType::WriteAfterRead;
131   }
132   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
133     return DependencyType::Control;
134   if (ToI->isTerminator())
135     return DependencyType::Control;
136   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
137       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
138     return DependencyType::Other;
139   return DependencyType::None;
140 }
141 
142 static bool isOrdered(Instruction *I) {
143   auto IsOrdered = [](Instruction *I) {
144     if (auto *LI = dyn_cast<LoadInst>(I))
145       return !LI->isUnordered();
146     if (auto *SI = dyn_cast<StoreInst>(I))
147       return !SI->isUnordered();
148     if (DGNode::isFenceLike(I))
149       return true;
150     return false;
151   };
152   bool Is = IsOrdered(I);
153   assert((!Is || DGNode::isMemDepCandidate(I)) &&
154          "An ordered instruction must be a MemDepCandidate!");
155   return Is;
156 }
157 
158 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
159                             DependencyType DepType) {
160   std::optional<MemoryLocation> DstLocOpt =
161       Utils::memoryLocationGetOrNone(DstI);
162   if (!DstLocOpt)
163     return true;
164   // Check aliasing.
165   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
166          "Expected a mem instr");
167   // TODO: Check AABudget
168   ModRefInfo SrcModRef =
169       isOrdered(SrcI)
170           ? ModRefInfo::ModRef
171           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
172   switch (DepType) {
173   case DependencyType::ReadAfterWrite:
174   case DependencyType::WriteAfterWrite:
175     return isModSet(SrcModRef);
176   case DependencyType::WriteAfterRead:
177     return isRefSet(SrcModRef);
178   default:
179     llvm_unreachable("Expected only RAW, WAW and WAR!");
180   }
181 }
182 
183 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
184   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
185   switch (RoughDepType) {
186   case DependencyType::ReadAfterWrite:
187   case DependencyType::WriteAfterWrite:
188   case DependencyType::WriteAfterRead:
189     return alias(SrcI, DstI, RoughDepType);
190   case DependencyType::Control:
191     // Adding actual dep edges from PHIs/to terminator would just create too
192     // many edges, which would be bad for compile-time.
193     // So we ignore them in the DAG formation but handle them in the
194     // scheduler, while sorting the ready list.
195     return false;
196   case DependencyType::Other:
197     return true;
198   case DependencyType::None:
199     return false;
200   }
201   llvm_unreachable("Unknown DependencyType enum");
202 }
203 
204 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
205                                      const Interval<MemDGNode> &SrcScanRange) {
206   assert(isa<MemDGNode>(DstN) &&
207          "DstN is the mem dep destination, so it must be mem");
208   Instruction *DstI = DstN.getInstruction();
209   // Walk up the instruction chain from ScanRange bottom to top, looking for
210   // memory instrs that may alias.
211   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
212     Instruction *SrcI = SrcN.getInstruction();
213     if (hasDep(SrcI, DstI))
214       DstN.addMemPred(&SrcN);
215   }
216 }
217 
218 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
219   // Create Nodes only for the new sections of the DAG.
220   DGNode *LastN = getOrCreateNode(NewInterval.top());
221   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
222   for (Instruction &I : drop_begin(NewInterval)) {
223     auto *N = getOrCreateNode(&I);
224     // Build the Mem node chain.
225     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
226       MemN->setPrevNode(LastMemN);
227       if (LastMemN != nullptr)
228         LastMemN->setNextNode(MemN);
229       LastMemN = MemN;
230     }
231   }
232   // Link new MemDGNode chain with the old one, if any.
233   if (!DAGInterval.empty()) {
234     bool NewIsAbove = NewInterval.comesBefore(DAGInterval);
235     const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
236     const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
237     MemDGNode *LinkTopN =
238         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
239     MemDGNode *LinkBotN =
240         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
241     assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
242     if (LinkTopN != nullptr && LinkBotN != nullptr) {
243       LinkTopN->setNextNode(LinkBotN);
244       LinkBotN->setPrevNode(LinkTopN);
245     }
246 #ifndef NDEBUG
247     // TODO: Remove this once we've done enough testing.
248     // Check that the chain is well formed.
249     auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
250     MemDGNode *ChainTopN =
251         MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
252     MemDGNode *ChainBotN =
253         MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
254     if (ChainTopN != nullptr && ChainBotN != nullptr) {
255       for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
256            LastN = N, N = N->getNextNode()) {
257         assert(N == LastN->getNextNode() && "Bad chain!");
258         assert(N->getPrevNode() == LastN && "Bad chain!");
259       }
260     }
261 #endif // NDEBUG
262   }
263 }
264 
265 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
266   if (Instrs.empty())
267     return {};
268 
269   Interval<Instruction> InstrsInterval(Instrs);
270   Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
271   auto NewInterval = Union.getSingleDiff(DAGInterval);
272   if (NewInterval.empty())
273     return {};
274 
275   createNewNodes(NewInterval);
276 
277   // Create the dependencies.
278   //
279   // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval.
280   // +---+       -             -
281   // |   | SrcN  |             |
282   // |   |  |    | SrcRange    |
283   // |New|  v    |             | DstRange
284   // |   | DstN  -             |
285   // |   |                     |
286   // +---+                     -
287   // We are scanning for deps with destination in NewInterval and sources in
288   // NewInterval until DstN, for each DstN.
289   auto FullScan = [this](const Interval<Instruction> Intvl) {
290     auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this);
291     if (!DstRange.empty()) {
292       for (MemDGNode &DstN : drop_begin(DstRange)) {
293         auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
294         scanAndAddDeps(DstN, SrcRange);
295       }
296     }
297   };
298   if (DAGInterval.empty()) {
299     assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
300     FullScan(NewInterval);
301   }
302   // 2. The new section is below the old section.
303   // +---+       -
304   // |   |       |
305   // |Old| SrcN  |
306   // |   |  |    |
307   // +---+  |    | SrcRange
308   // +---+  |    |             -
309   // |   |  |    |             |
310   // |New|  v    |             | DstRange
311   // |   | DstN  -             |
312   // |   |                     |
313   // +---+                     -
314   // We are scanning for deps with destination in NewInterval because the deps
315   // in DAGInterval have already been computed. We consider sources in the whole
316   // range including both NewInterval and DAGInterval until DstN, for each DstN.
317   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
318     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
319     auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
320         DAGInterval.getUnionInterval(NewInterval), *this);
321     for (MemDGNode &DstN : DstRange) {
322       auto SrcRange =
323           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
324       scanAndAddDeps(DstN, SrcRange);
325     }
326   }
327   // 3. The new section is above the old section.
328   else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
329     // +---+       -             -
330     // |   | SrcN  |             |
331     // |New|  |    | SrcRange    | DstRange
332     // |   |  v    |             |
333     // |   | DstN  -             |
334     // |   |                     |
335     // +---+                     -
336     // +---+
337     // |Old|
338     // |   |
339     // +---+
340     // When scanning for deps with destination in NewInterval we need to fully
341     // scan the interval. This is the same as the scanning for a new DAG.
342     FullScan(NewInterval);
343 
344     // +---+       -
345     // |   |       |
346     // |New| SrcN  | SrcRange
347     // |   |  |    |
348     // |   |  |    |
349     // |   |  |    |
350     // +---+  |    -
351     // +---+  |                  -
352     // |Old|  v                  | DstRange
353     // |   | DstN                |
354     // +---+                     -
355     // When scanning for deps with destination in DAGInterval we need to
356     // consider sources from the NewInterval only, because all intra-DAGInterval
357     // dependencies have already been created.
358     auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
359     auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
360     for (MemDGNode &DstN : DstRangeOld)
361       scanAndAddDeps(DstN, SrcRange);
362   } else {
363     llvm_unreachable("We don't expect extending in both directions!");
364   }
365 
366   DAGInterval = Union;
367   return NewInterval;
368 }
369 
370 #ifndef NDEBUG
371 void DependencyGraph::print(raw_ostream &OS) const {
372   // InstrToNodeMap is unordered so we need to create an ordered vector.
373   SmallVector<DGNode *> Nodes;
374   Nodes.reserve(InstrToNodeMap.size());
375   for (const auto &Pair : InstrToNodeMap)
376     Nodes.push_back(Pair.second.get());
377   // Sort them based on which one comes first in the BB.
378   sort(Nodes, [](DGNode *N1, DGNode *N2) {
379     return N1->getInstruction()->comesBefore(N2->getInstruction());
380   });
381   for (auto *N : Nodes)
382     N->print(OS, /*PrintDeps=*/true);
383 }
384 
385 void DependencyGraph::dump() const {
386   print(dbgs());
387   dbgs() << "\n";
388 }
389 #endif // NDEBUG
390 
391 } // namespace llvm::sandboxir
392