xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision 69c0067927293bff1401a9a050081e83dbefd282)
1 //===- DependencyGraph.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Utils.h"
13 
14 namespace llvm::sandboxir {
15 
16 PredIterator::value_type PredIterator::operator*() {
17   // If it's a DGNode then we dereference the operand iterator.
18   if (!isa<MemDGNode>(N)) {
19     assert(OpIt != OpItE && "Can't dereference end iterator!");
20     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
21   }
22   // It's a MemDGNode, so we check if we return either the use-def operand,
23   // or a mem predecessor.
24   if (OpIt != OpItE)
25     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
27   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
28          "Cant' dereference end iterator!");
29   return *MemIt;
30 }
31 
32 PredIterator &PredIterator::operator++() {
33   // If it's a DGNode then we increment the use-def iterator.
34   if (!isa<MemDGNode>(N)) {
35     assert(OpIt != OpItE && "Already at end!");
36     ++OpIt;
37     // Skip operands that are not instructions.
38     OpIt = skipNonInstr(OpIt, OpItE);
39     return *this;
40   }
41   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
42   // need to first increment that.
43   if (OpIt != OpItE) {
44     ++OpIt;
45     // Skip operands that are not instructions.
46     OpIt = skipNonInstr(OpIt, OpItE);
47     return *this;
48   }
49   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
50   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
51   ++MemIt;
52   return *this;
53 }
54 
55 bool PredIterator::operator==(const PredIterator &Other) const {
56   assert(DAG == Other.DAG && "Iterators of different DAGs!");
57   assert(N == Other.N && "Iterators of different nodes!");
58   return OpIt == Other.OpIt && MemIt == Other.MemIt;
59 }
60 
61 #ifndef NDEBUG
62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); }
63 void DGNode::dump() const {
64   print(dbgs());
65   dbgs() << "\n";
66 }
67 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
68   I->dumpOS(OS);
69   if (PrintDeps) {
70     // Print memory preds.
71     static constexpr const unsigned Indent = 4;
72     for (auto *Pred : MemPreds) {
73       OS.indent(Indent) << "<-";
74       Pred->print(OS, false);
75       OS << "\n";
76     }
77   }
78 }
79 #endif // NDEBUG
80 
81 MemDGNode *
82 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
83                                           const DependencyGraph &DAG) {
84   Instruction *I = Intvl.top();
85   Instruction *BeforeI = Intvl.bottom();
86   // Walk down the chain looking for a mem-dep candidate instruction.
87   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
88     I = I->getNextNode();
89   if (!DGNode::isMemDepNodeCandidate(I))
90     return nullptr;
91   return cast<MemDGNode>(DAG.getNode(I));
92 }
93 
94 MemDGNode *
95 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
96                                           const DependencyGraph &DAG) {
97   Instruction *I = Intvl.bottom();
98   Instruction *AfterI = Intvl.top();
99   // Walk up the chain looking for a mem-dep candidate instruction.
100   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
101     I = I->getPrevNode();
102   if (!DGNode::isMemDepNodeCandidate(I))
103     return nullptr;
104   return cast<MemDGNode>(DAG.getNode(I));
105 }
106 
107 Interval<MemDGNode>
108 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
109                                DependencyGraph &DAG) {
110   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
111   // If we couldn't find a mem node in range TopN - BotN then it's empty.
112   if (TopMemN == nullptr)
113     return {};
114   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
115   assert(BotMemN != nullptr && "TopMemN should be null too!");
116   // Now that we have the mem-dep nodes, create and return the range.
117   return Interval<MemDGNode>(TopMemN, BotMemN);
118 }
119 
120 DependencyGraph::DependencyType
121 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
122   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
123   if (FromI->mayWriteToMemory()) {
124     if (ToI->mayReadFromMemory())
125       return DependencyType::ReadAfterWrite;
126     if (ToI->mayWriteToMemory())
127       return DependencyType::WriteAfterWrite;
128   } else if (FromI->mayReadFromMemory()) {
129     if (ToI->mayWriteToMemory())
130       return DependencyType::WriteAfterRead;
131   }
132   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
133     return DependencyType::Control;
134   if (ToI->isTerminator())
135     return DependencyType::Control;
136   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
137       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
138     return DependencyType::Other;
139   return DependencyType::None;
140 }
141 
142 static bool isOrdered(Instruction *I) {
143   auto IsOrdered = [](Instruction *I) {
144     if (auto *LI = dyn_cast<LoadInst>(I))
145       return !LI->isUnordered();
146     if (auto *SI = dyn_cast<StoreInst>(I))
147       return !SI->isUnordered();
148     if (DGNode::isFenceLike(I))
149       return true;
150     return false;
151   };
152   bool Is = IsOrdered(I);
153   assert((!Is || DGNode::isMemDepCandidate(I)) &&
154          "An ordered instruction must be a MemDepCandidate!");
155   return Is;
156 }
157 
158 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
159                             DependencyType DepType) {
160   std::optional<MemoryLocation> DstLocOpt =
161       Utils::memoryLocationGetOrNone(DstI);
162   if (!DstLocOpt)
163     return true;
164   // Check aliasing.
165   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
166          "Expected a mem instr");
167   // TODO: Check AABudget
168   ModRefInfo SrcModRef =
169       isOrdered(SrcI)
170           ? ModRefInfo::ModRef
171           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
172   switch (DepType) {
173   case DependencyType::ReadAfterWrite:
174   case DependencyType::WriteAfterWrite:
175     return isModSet(SrcModRef);
176   case DependencyType::WriteAfterRead:
177     return isRefSet(SrcModRef);
178   default:
179     llvm_unreachable("Expected only RAW, WAW and WAR!");
180   }
181 }
182 
183 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
184   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
185   switch (RoughDepType) {
186   case DependencyType::ReadAfterWrite:
187   case DependencyType::WriteAfterWrite:
188   case DependencyType::WriteAfterRead:
189     return alias(SrcI, DstI, RoughDepType);
190   case DependencyType::Control:
191     // Adding actual dep edges from PHIs/to terminator would just create too
192     // many edges, which would be bad for compile-time.
193     // So we ignore them in the DAG formation but handle them in the
194     // scheduler, while sorting the ready list.
195     return false;
196   case DependencyType::Other:
197     return true;
198   case DependencyType::None:
199     return false;
200   }
201   llvm_unreachable("Unknown DependencyType enum");
202 }
203 
204 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
205                                      const Interval<MemDGNode> &SrcScanRange) {
206   assert(isa<MemDGNode>(DstN) &&
207          "DstN is the mem dep destination, so it must be mem");
208   Instruction *DstI = DstN.getInstruction();
209   // Walk up the instruction chain from ScanRange bottom to top, looking for
210   // memory instrs that may alias.
211   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
212     Instruction *SrcI = SrcN.getInstruction();
213     if (hasDep(SrcI, DstI))
214       DstN.addMemPred(&SrcN);
215   }
216 }
217 
218 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
219   if (Instrs.empty())
220     return {};
221 
222   Interval<Instruction> InstrInterval(Instrs);
223 
224   DGNode *LastN = getOrCreateNode(InstrInterval.top());
225   // Create DGNodes for all instrs in Interval to avoid future Instruction to
226   // DGNode lookups.
227   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
228   for (Instruction &I : drop_begin(InstrInterval)) {
229     auto *N = getOrCreateNode(&I);
230     // Build the Mem node chain.
231     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
232       MemN->setPrevNode(LastMemN);
233       if (LastMemN != nullptr)
234         LastMemN->setNextNode(MemN);
235       LastMemN = MemN;
236     }
237   }
238   // Create the dependencies.
239   auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
240   if (!DstRange.empty()) {
241     for (MemDGNode &DstN : drop_begin(DstRange)) {
242       auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
243       scanAndAddDeps(DstN, SrcRange);
244     }
245   }
246 
247   return InstrInterval;
248 }
249 
250 #ifndef NDEBUG
251 void DependencyGraph::print(raw_ostream &OS) const {
252   // InstrToNodeMap is unordered so we need to create an ordered vector.
253   SmallVector<DGNode *> Nodes;
254   Nodes.reserve(InstrToNodeMap.size());
255   for (const auto &Pair : InstrToNodeMap)
256     Nodes.push_back(Pair.second.get());
257   // Sort them based on which one comes first in the BB.
258   sort(Nodes, [](DGNode *N1, DGNode *N2) {
259     return N1->getInstruction()->comesBefore(N2->getInstruction());
260   });
261   for (auto *N : Nodes)
262     N->print(OS, /*PrintDeps=*/true);
263 }
264 
265 void DependencyGraph::dump() const {
266   print(dbgs());
267   dbgs() << "\n";
268 }
269 #endif // NDEBUG
270 
271 } // namespace llvm::sandboxir
272