xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision e8dd95e97bd45c8ee3cc2a3d95c9a6198a970d80)
1 //===- DependencyGraph.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Utils.h"
13 
14 namespace llvm::sandboxir {
15 
16 PredIterator::value_type PredIterator::operator*() {
17   // If it's a DGNode then we dereference the operand iterator.
18   if (!isa<MemDGNode>(N)) {
19     assert(OpIt != OpItE && "Can't dereference end iterator!");
20     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
21   }
22   // It's a MemDGNode, so we check if we return either the use-def operand,
23   // or a mem predecessor.
24   if (OpIt != OpItE)
25     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
27   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
28          "Cant' dereference end iterator!");
29   return *MemIt;
30 }
31 
32 PredIterator &PredIterator::operator++() {
33   // If it's a DGNode then we increment the use-def iterator.
34   if (!isa<MemDGNode>(N)) {
35     assert(OpIt != OpItE && "Already at end!");
36     ++OpIt;
37     // Skip operands that are not instructions.
38     OpIt = skipNonInstr(OpIt, OpItE);
39     return *this;
40   }
41   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
42   // need to first increment that.
43   if (OpIt != OpItE) {
44     ++OpIt;
45     // Skip operands that are not instructions.
46     OpIt = skipNonInstr(OpIt, OpItE);
47     return *this;
48   }
49   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
50   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
51   ++MemIt;
52   return *this;
53 }
54 
55 bool PredIterator::operator==(const PredIterator &Other) const {
56   assert(DAG == Other.DAG && "Iterators of different DAGs!");
57   assert(N == Other.N && "Iterators of different nodes!");
58   return OpIt == Other.OpIt && MemIt == Other.MemIt;
59 }
60 
61 #ifndef NDEBUG
62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); }
63 void DGNode::dump() const {
64   print(dbgs());
65   dbgs() << "\n";
66 }
67 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
68   I->dumpOS(OS);
69   if (PrintDeps) {
70     // Print memory preds.
71     static constexpr const unsigned Indent = 4;
72     for (auto *Pred : MemPreds) {
73       OS.indent(Indent) << "<-";
74       Pred->print(OS, false);
75       OS << "\n";
76     }
77   }
78 }
79 #endif // NDEBUG
80 
81 MemDGNode *
82 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
83                                           const DependencyGraph &DAG) {
84   Instruction *I = Intvl.top();
85   Instruction *BeforeI = Intvl.bottom();
86   // Walk down the chain looking for a mem-dep candidate instruction.
87   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
88     I = I->getNextNode();
89   if (!DGNode::isMemDepNodeCandidate(I))
90     return nullptr;
91   return cast<MemDGNode>(DAG.getNode(I));
92 }
93 
94 MemDGNode *
95 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
96                                           const DependencyGraph &DAG) {
97   Instruction *I = Intvl.bottom();
98   Instruction *AfterI = Intvl.top();
99   // Walk up the chain looking for a mem-dep candidate instruction.
100   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
101     I = I->getPrevNode();
102   if (!DGNode::isMemDepNodeCandidate(I))
103     return nullptr;
104   return cast<MemDGNode>(DAG.getNode(I));
105 }
106 
107 Interval<MemDGNode>
108 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
109                                DependencyGraph &DAG) {
110   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
111   // If we couldn't find a mem node in range TopN - BotN then it's empty.
112   if (TopMemN == nullptr)
113     return {};
114   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
115   assert(BotMemN != nullptr && "TopMemN should be null too!");
116   // Now that we have the mem-dep nodes, create and return the range.
117   return Interval<MemDGNode>(TopMemN, BotMemN);
118 }
119 
120 DependencyGraph::DependencyType
121 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
122   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
123   if (FromI->mayWriteToMemory()) {
124     if (ToI->mayReadFromMemory())
125       return DependencyType::ReadAfterWrite;
126     if (ToI->mayWriteToMemory())
127       return DependencyType::WriteAfterWrite;
128   } else if (FromI->mayReadFromMemory()) {
129     if (ToI->mayWriteToMemory())
130       return DependencyType::WriteAfterRead;
131   }
132   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
133     return DependencyType::Control;
134   if (ToI->isTerminator())
135     return DependencyType::Control;
136   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
137       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
138     return DependencyType::Other;
139   return DependencyType::None;
140 }
141 
142 static bool isOrdered(Instruction *I) {
143   auto IsOrdered = [](Instruction *I) {
144     if (auto *LI = dyn_cast<LoadInst>(I))
145       return !LI->isUnordered();
146     if (auto *SI = dyn_cast<StoreInst>(I))
147       return !SI->isUnordered();
148     if (DGNode::isFenceLike(I))
149       return true;
150     return false;
151   };
152   bool Is = IsOrdered(I);
153   assert((!Is || DGNode::isMemDepCandidate(I)) &&
154          "An ordered instruction must be a MemDepCandidate!");
155   return Is;
156 }
157 
158 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
159                             DependencyType DepType) {
160   std::optional<MemoryLocation> DstLocOpt =
161       Utils::memoryLocationGetOrNone(DstI);
162   if (!DstLocOpt)
163     return true;
164   // Check aliasing.
165   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
166          "Expected a mem instr");
167   // TODO: Check AABudget
168   ModRefInfo SrcModRef =
169       isOrdered(SrcI)
170           ? ModRefInfo::ModRef
171           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
172   switch (DepType) {
173   case DependencyType::ReadAfterWrite:
174   case DependencyType::WriteAfterWrite:
175     return isModSet(SrcModRef);
176   case DependencyType::WriteAfterRead:
177     return isRefSet(SrcModRef);
178   default:
179     llvm_unreachable("Expected only RAW, WAW and WAR!");
180   }
181 }
182 
183 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
184   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
185   switch (RoughDepType) {
186   case DependencyType::ReadAfterWrite:
187   case DependencyType::WriteAfterWrite:
188   case DependencyType::WriteAfterRead:
189     return alias(SrcI, DstI, RoughDepType);
190   case DependencyType::Control:
191     // Adding actual dep edges from PHIs/to terminator would just create too
192     // many edges, which would be bad for compile-time.
193     // So we ignore them in the DAG formation but handle them in the
194     // scheduler, while sorting the ready list.
195     return false;
196   case DependencyType::Other:
197     return true;
198   case DependencyType::None:
199     return false;
200   }
201   llvm_unreachable("Unknown DependencyType enum");
202 }
203 
204 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
205                                      const Interval<MemDGNode> &SrcScanRange) {
206   assert(isa<MemDGNode>(DstN) &&
207          "DstN is the mem dep destination, so it must be mem");
208   Instruction *DstI = DstN.getInstruction();
209   // Walk up the instruction chain from ScanRange bottom to top, looking for
210   // memory instrs that may alias.
211   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
212     Instruction *SrcI = SrcN.getInstruction();
213     if (hasDep(SrcI, DstI))
214       DstN.addMemPred(&SrcN);
215   }
216 }
217 
218 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
219   // Create Nodes only for the new sections of the DAG.
220   DGNode *LastN = getOrCreateNode(NewInterval.top());
221   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
222   for (Instruction &I : drop_begin(NewInterval)) {
223     auto *N = getOrCreateNode(&I);
224     // Build the Mem node chain.
225     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
226       MemN->setPrevNode(LastMemN);
227       if (LastMemN != nullptr)
228         LastMemN->setNextNode(MemN);
229       LastMemN = MemN;
230     }
231   }
232   // Link new MemDGNode chain with the old one, if any.
233   if (!DAGInterval.empty()) {
234     // TODO: Implement Interval::comesBefore() to replace this check.
235     bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.top());
236     assert(
237         (NewIsAbove || DAGInterval.bottom()->comesBefore(NewInterval.top())) &&
238         "Expected NewInterval below DAGInterval.");
239     const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
240     const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
241     MemDGNode *LinkTopN =
242         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
243     MemDGNode *LinkBotN =
244         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
245     assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
246     if (LinkTopN != nullptr && LinkBotN != nullptr) {
247       LinkTopN->setNextNode(LinkBotN);
248       LinkBotN->setPrevNode(LinkTopN);
249     }
250 #ifndef NDEBUG
251     // TODO: Remove this once we've done enough testing.
252     // Check that the chain is well formed.
253     auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
254     MemDGNode *ChainTopN =
255         MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
256     MemDGNode *ChainBotN =
257         MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
258     if (ChainTopN != nullptr && ChainBotN != nullptr) {
259       for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
260            LastN = N, N = N->getNextNode()) {
261         assert(N == LastN->getNextNode() && "Bad chain!");
262         assert(N->getPrevNode() == LastN && "Bad chain!");
263       }
264     }
265 #endif // NDEBUG
266   }
267 }
268 
269 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
270   if (Instrs.empty())
271     return {};
272 
273   Interval<Instruction> InstrsInterval(Instrs);
274   Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
275   auto NewInterval = Union.getSingleDiff(DAGInterval);
276   if (NewInterval.empty())
277     return {};
278 
279   createNewNodes(NewInterval);
280 
281   // Create the dependencies.
282   //
283   // 1. DAGInterval empty      2. New is below Old     3. New is above old
284   // ------------------------  -------------------      -------------------
285   //                                         Scan:           DstN:    Scan:
286   //                           +---+         -ScanTopN  +---+DstTopN  -ScanTopN
287   //                           |   |         |          |New|         |
288   //                           |Old|         |          +---+         -ScanBotN
289   //                           |   |         |          +---+
290   //      DstN:    Scan:       +---+DstN:    |          |   |
291   // +---+DstTopN  -ScanTopN   +---+DstTopN  |          |Old|
292   // |New|         |           |New|         |          |   |
293   // +---+DstBotN  -ScanBotN   +---+DstBotN  -ScanBotN  +---+DstBotN
294 
295   // 1. This is a new DAG.
296   if (DAGInterval.empty()) {
297     assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
298     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
299     if (!DstRange.empty()) {
300       for (MemDGNode &DstN : drop_begin(DstRange)) {
301         auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
302         scanAndAddDeps(DstN, SrcRange);
303       }
304     }
305   }
306   // 2. The new section is below the old section.
307   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
308     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
309     auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
310         DAGInterval.getUnionInterval(NewInterval), *this);
311     for (MemDGNode &DstN : DstRange) {
312       auto SrcRange =
313           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
314       scanAndAddDeps(DstN, SrcRange);
315     }
316   }
317   // 3. The new section is above the old section.
318   else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
319     auto DstRange = MemDGNodeIntervalBuilder::make(
320         NewInterval.getUnionInterval(DAGInterval), *this);
321     auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
322     if (!DstRange.empty()) {
323       for (MemDGNode &DstN : drop_begin(DstRange)) {
324         auto SrcRange =
325             Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
326         scanAndAddDeps(DstN, SrcRange);
327       }
328     }
329   } else {
330     llvm_unreachable("We don't expect extending in both directions!");
331   }
332 
333   DAGInterval = Union;
334   return NewInterval;
335 }
336 
337 #ifndef NDEBUG
338 void DependencyGraph::print(raw_ostream &OS) const {
339   // InstrToNodeMap is unordered so we need to create an ordered vector.
340   SmallVector<DGNode *> Nodes;
341   Nodes.reserve(InstrToNodeMap.size());
342   for (const auto &Pair : InstrToNodeMap)
343     Nodes.push_back(Pair.second.get());
344   // Sort them based on which one comes first in the BB.
345   sort(Nodes, [](DGNode *N1, DGNode *N2) {
346     return N1->getInstruction()->comesBefore(N2->getInstruction());
347   });
348   for (auto *N : Nodes)
349     N->print(OS, /*PrintDeps=*/true);
350 }
351 
352 void DependencyGraph::dump() const {
353   print(dbgs());
354   dbgs() << "\n";
355 }
356 #endif // NDEBUG
357 
358 } // namespace llvm::sandboxir
359