xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision fc08ad6610c66856f48559e543eb7be317e908e7)
1 //===- DependencyGraph.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Utils.h"
13 
14 namespace llvm::sandboxir {
15 
16 PredIterator::value_type PredIterator::operator*() {
17   // If it's a DGNode then we dereference the operand iterator.
18   if (!isa<MemDGNode>(N)) {
19     assert(OpIt != OpItE && "Can't dereference end iterator!");
20     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
21   }
22   // It's a MemDGNode, so we check if we return either the use-def operand,
23   // or a mem predecessor.
24   if (OpIt != OpItE)
25     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
27   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
28          "Cant' dereference end iterator!");
29   return *MemIt;
30 }
31 
32 PredIterator &PredIterator::operator++() {
33   // If it's a DGNode then we increment the use-def iterator.
34   if (!isa<MemDGNode>(N)) {
35     assert(OpIt != OpItE && "Already at end!");
36     ++OpIt;
37     // Skip operands that are not instructions.
38     OpIt = skipNonInstr(OpIt, OpItE);
39     return *this;
40   }
41   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
42   // need to first increment that.
43   if (OpIt != OpItE) {
44     ++OpIt;
45     // Skip operands that are not instructions.
46     OpIt = skipNonInstr(OpIt, OpItE);
47     return *this;
48   }
49   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
50   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
51   ++MemIt;
52   return *this;
53 }
54 
55 bool PredIterator::operator==(const PredIterator &Other) const {
56   assert(DAG == Other.DAG && "Iterators of different DAGs!");
57   assert(N == Other.N && "Iterators of different nodes!");
58   return OpIt == Other.OpIt && MemIt == Other.MemIt;
59 }
60 
61 #ifndef NDEBUG
62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
63   OS << *I << " USuccs:" << UnscheduledSuccs << "\n";
64 }
65 void DGNode::dump() const { print(dbgs()); }
66 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
67   DGNode::print(OS, false);
68   if (PrintDeps) {
69     // Print memory preds.
70     static constexpr const unsigned Indent = 4;
71     for (auto *Pred : MemPreds)
72       OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n";
73   }
74 }
75 #endif // NDEBUG
76 
77 MemDGNode *
78 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
79                                           const DependencyGraph &DAG) {
80   Instruction *I = Intvl.top();
81   Instruction *BeforeI = Intvl.bottom();
82   // Walk down the chain looking for a mem-dep candidate instruction.
83   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
84     I = I->getNextNode();
85   if (!DGNode::isMemDepNodeCandidate(I))
86     return nullptr;
87   return cast<MemDGNode>(DAG.getNode(I));
88 }
89 
90 MemDGNode *
91 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
92                                           const DependencyGraph &DAG) {
93   Instruction *I = Intvl.bottom();
94   Instruction *AfterI = Intvl.top();
95   // Walk up the chain looking for a mem-dep candidate instruction.
96   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
97     I = I->getPrevNode();
98   if (!DGNode::isMemDepNodeCandidate(I))
99     return nullptr;
100   return cast<MemDGNode>(DAG.getNode(I));
101 }
102 
103 Interval<MemDGNode>
104 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
105                                DependencyGraph &DAG) {
106   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
107   // If we couldn't find a mem node in range TopN - BotN then it's empty.
108   if (TopMemN == nullptr)
109     return {};
110   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
111   assert(BotMemN != nullptr && "TopMemN should be null too!");
112   // Now that we have the mem-dep nodes, create and return the range.
113   return Interval<MemDGNode>(TopMemN, BotMemN);
114 }
115 
116 DependencyGraph::DependencyType
117 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
118   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
119   if (FromI->mayWriteToMemory()) {
120     if (ToI->mayReadFromMemory())
121       return DependencyType::ReadAfterWrite;
122     if (ToI->mayWriteToMemory())
123       return DependencyType::WriteAfterWrite;
124   } else if (FromI->mayReadFromMemory()) {
125     if (ToI->mayWriteToMemory())
126       return DependencyType::WriteAfterRead;
127   }
128   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
129     return DependencyType::Control;
130   if (ToI->isTerminator())
131     return DependencyType::Control;
132   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
133       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
134     return DependencyType::Other;
135   return DependencyType::None;
136 }
137 
138 static bool isOrdered(Instruction *I) {
139   auto IsOrdered = [](Instruction *I) {
140     if (auto *LI = dyn_cast<LoadInst>(I))
141       return !LI->isUnordered();
142     if (auto *SI = dyn_cast<StoreInst>(I))
143       return !SI->isUnordered();
144     if (DGNode::isFenceLike(I))
145       return true;
146     return false;
147   };
148   bool Is = IsOrdered(I);
149   assert((!Is || DGNode::isMemDepCandidate(I)) &&
150          "An ordered instruction must be a MemDepCandidate!");
151   return Is;
152 }
153 
154 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
155                             DependencyType DepType) {
156   std::optional<MemoryLocation> DstLocOpt =
157       Utils::memoryLocationGetOrNone(DstI);
158   if (!DstLocOpt)
159     return true;
160   // Check aliasing.
161   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
162          "Expected a mem instr");
163   // TODO: Check AABudget
164   ModRefInfo SrcModRef =
165       isOrdered(SrcI)
166           ? ModRefInfo::ModRef
167           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
168   switch (DepType) {
169   case DependencyType::ReadAfterWrite:
170   case DependencyType::WriteAfterWrite:
171     return isModSet(SrcModRef);
172   case DependencyType::WriteAfterRead:
173     return isRefSet(SrcModRef);
174   default:
175     llvm_unreachable("Expected only RAW, WAW and WAR!");
176   }
177 }
178 
179 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
180   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
181   switch (RoughDepType) {
182   case DependencyType::ReadAfterWrite:
183   case DependencyType::WriteAfterWrite:
184   case DependencyType::WriteAfterRead:
185     return alias(SrcI, DstI, RoughDepType);
186   case DependencyType::Control:
187     // Adding actual dep edges from PHIs/to terminator would just create too
188     // many edges, which would be bad for compile-time.
189     // So we ignore them in the DAG formation but handle them in the
190     // scheduler, while sorting the ready list.
191     return false;
192   case DependencyType::Other:
193     return true;
194   case DependencyType::None:
195     return false;
196   }
197   llvm_unreachable("Unknown DependencyType enum");
198 }
199 
200 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
201                                      const Interval<MemDGNode> &SrcScanRange) {
202   assert(isa<MemDGNode>(DstN) &&
203          "DstN is the mem dep destination, so it must be mem");
204   Instruction *DstI = DstN.getInstruction();
205   // Walk up the instruction chain from ScanRange bottom to top, looking for
206   // memory instrs that may alias.
207   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
208     Instruction *SrcI = SrcN.getInstruction();
209     if (hasDep(SrcI, DstI))
210       DstN.addMemPred(&SrcN);
211   }
212 }
213 
214 void DependencyGraph::setDefUseUnscheduledSuccs(
215     const Interval<Instruction> &NewInterval) {
216   // +---+
217   // |   |  Def
218   // |   |   |
219   // |   |   v
220   // |   |  Use
221   // +---+
222   // Set the intra-interval counters in NewInterval.
223   for (Instruction &I : NewInterval) {
224     for (Value *Op : I.operands()) {
225       auto *OpI = dyn_cast<Instruction>(Op);
226       if (OpI == nullptr)
227         continue;
228       if (!NewInterval.contains(OpI))
229         continue;
230       auto *OpN = getNode(OpI);
231       if (OpN == nullptr)
232         continue;
233       ++OpN->UnscheduledSuccs;
234     }
235   }
236 
237   // Now handle the cross-interval edges.
238   bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval);
239   const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
240   const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
241   // +---+
242   // |Top|
243   // |   |  Def
244   // +---+   |
245   // |   |   v
246   // |Bot|  Use
247   // |   |
248   // +---+
249   // Walk over all instructions in "BotInterval" and update the counter
250   // of operands that are in "TopInterval".
251   for (Instruction &BotI : BotInterval) {
252     for (Value *Op : BotI.operands()) {
253       auto *OpI = dyn_cast<Instruction>(Op);
254       if (OpI == nullptr)
255         continue;
256       if (!TopInterval.contains(OpI))
257         continue;
258       auto *OpN = getNode(OpI);
259       if (OpN == nullptr)
260         continue;
261       ++OpN->UnscheduledSuccs;
262     }
263   }
264 }
265 
266 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
267   // Create Nodes only for the new sections of the DAG.
268   DGNode *LastN = getOrCreateNode(NewInterval.top());
269   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
270   for (Instruction &I : drop_begin(NewInterval)) {
271     auto *N = getOrCreateNode(&I);
272     // Build the Mem node chain.
273     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
274       MemN->setPrevNode(LastMemN);
275       if (LastMemN != nullptr)
276         LastMemN->setNextNode(MemN);
277       LastMemN = MemN;
278     }
279   }
280   // Link new MemDGNode chain with the old one, if any.
281   if (!DAGInterval.empty()) {
282     bool NewIsAbove = NewInterval.comesBefore(DAGInterval);
283     const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
284     const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
285     MemDGNode *LinkTopN =
286         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
287     MemDGNode *LinkBotN =
288         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
289     assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
290     if (LinkTopN != nullptr && LinkBotN != nullptr) {
291       LinkTopN->setNextNode(LinkBotN);
292       LinkBotN->setPrevNode(LinkTopN);
293     }
294 #ifndef NDEBUG
295     // TODO: Remove this once we've done enough testing.
296     // Check that the chain is well formed.
297     auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
298     MemDGNode *ChainTopN =
299         MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
300     MemDGNode *ChainBotN =
301         MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
302     if (ChainTopN != nullptr && ChainBotN != nullptr) {
303       for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
304            LastN = N, N = N->getNextNode()) {
305         assert(N == LastN->getNextNode() && "Bad chain!");
306         assert(N->getPrevNode() == LastN && "Bad chain!");
307       }
308     }
309 #endif // NDEBUG
310   }
311 
312   setDefUseUnscheduledSuccs(NewInterval);
313 }
314 
315 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
316   if (Instrs.empty())
317     return {};
318 
319   Interval<Instruction> InstrsInterval(Instrs);
320   Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
321   auto NewInterval = Union.getSingleDiff(DAGInterval);
322   if (NewInterval.empty())
323     return {};
324 
325   createNewNodes(NewInterval);
326 
327   // Create the dependencies.
328   //
329   // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval.
330   // +---+       -             -
331   // |   | SrcN  |             |
332   // |   |  |    | SrcRange    |
333   // |New|  v    |             | DstRange
334   // |   | DstN  -             |
335   // |   |                     |
336   // +---+                     -
337   // We are scanning for deps with destination in NewInterval and sources in
338   // NewInterval until DstN, for each DstN.
339   auto FullScan = [this](const Interval<Instruction> Intvl) {
340     auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this);
341     if (!DstRange.empty()) {
342       for (MemDGNode &DstN : drop_begin(DstRange)) {
343         auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
344         scanAndAddDeps(DstN, SrcRange);
345       }
346     }
347   };
348   if (DAGInterval.empty()) {
349     assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
350     FullScan(NewInterval);
351   }
352   // 2. The new section is below the old section.
353   // +---+       -
354   // |   |       |
355   // |Old| SrcN  |
356   // |   |  |    |
357   // +---+  |    | SrcRange
358   // +---+  |    |             -
359   // |   |  |    |             |
360   // |New|  v    |             | DstRange
361   // |   | DstN  -             |
362   // |   |                     |
363   // +---+                     -
364   // We are scanning for deps with destination in NewInterval because the deps
365   // in DAGInterval have already been computed. We consider sources in the whole
366   // range including both NewInterval and DAGInterval until DstN, for each DstN.
367   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
368     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
369     auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
370         DAGInterval.getUnionInterval(NewInterval), *this);
371     for (MemDGNode &DstN : DstRange) {
372       auto SrcRange =
373           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
374       scanAndAddDeps(DstN, SrcRange);
375     }
376   }
377   // 3. The new section is above the old section.
378   else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
379     // +---+       -             -
380     // |   | SrcN  |             |
381     // |New|  |    | SrcRange    | DstRange
382     // |   |  v    |             |
383     // |   | DstN  -             |
384     // |   |                     |
385     // +---+                     -
386     // +---+
387     // |Old|
388     // |   |
389     // +---+
390     // When scanning for deps with destination in NewInterval we need to fully
391     // scan the interval. This is the same as the scanning for a new DAG.
392     FullScan(NewInterval);
393 
394     // +---+       -
395     // |   |       |
396     // |New| SrcN  | SrcRange
397     // |   |  |    |
398     // |   |  |    |
399     // |   |  |    |
400     // +---+  |    -
401     // +---+  |                  -
402     // |Old|  v                  | DstRange
403     // |   | DstN                |
404     // +---+                     -
405     // When scanning for deps with destination in DAGInterval we need to
406     // consider sources from the NewInterval only, because all intra-DAGInterval
407     // dependencies have already been created.
408     auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
409     auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
410     for (MemDGNode &DstN : DstRangeOld)
411       scanAndAddDeps(DstN, SrcRange);
412   } else {
413     llvm_unreachable("We don't expect extending in both directions!");
414   }
415 
416   DAGInterval = Union;
417   return NewInterval;
418 }
419 
420 #ifndef NDEBUG
421 void DependencyGraph::print(raw_ostream &OS) const {
422   // InstrToNodeMap is unordered so we need to create an ordered vector.
423   SmallVector<DGNode *> Nodes;
424   Nodes.reserve(InstrToNodeMap.size());
425   for (const auto &Pair : InstrToNodeMap)
426     Nodes.push_back(Pair.second.get());
427   // Sort them based on which one comes first in the BB.
428   sort(Nodes, [](DGNode *N1, DGNode *N2) {
429     return N1->getInstruction()->comesBefore(N2->getInstruction());
430   });
431   for (auto *N : Nodes)
432     N->print(OS, /*PrintDeps=*/true);
433 }
434 
435 void DependencyGraph::dump() const {
436   print(dbgs());
437   dbgs() << "\n";
438 }
439 #endif // NDEBUG
440 
441 } // namespace llvm::sandboxir
442