1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 14 namespace llvm::sandboxir { 15 16 PredIterator::value_type PredIterator::operator*() { 17 // If it's a DGNode then we dereference the operand iterator. 18 if (!isa<MemDGNode>(N)) { 19 assert(OpIt != OpItE && "Can't dereference end iterator!"); 20 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 21 } 22 // It's a MemDGNode, so we check if we return either the use-def operand, 23 // or a mem predecessor. 24 if (OpIt != OpItE) 25 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 26 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 27 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 28 "Cant' dereference end iterator!"); 29 return *MemIt; 30 } 31 32 PredIterator &PredIterator::operator++() { 33 // If it's a DGNode then we increment the use-def iterator. 34 if (!isa<MemDGNode>(N)) { 35 assert(OpIt != OpItE && "Already at end!"); 36 ++OpIt; 37 // Skip operands that are not instructions. 38 OpIt = skipNonInstr(OpIt, OpItE); 39 return *this; 40 } 41 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 42 // need to first increment that. 43 if (OpIt != OpItE) { 44 ++OpIt; 45 // Skip operands that are not instructions. 46 OpIt = skipNonInstr(OpIt, OpItE); 47 return *this; 48 } 49 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 50 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 51 ++MemIt; 52 return *this; 53 } 54 55 bool PredIterator::operator==(const PredIterator &Other) const { 56 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 57 assert(N == Other.N && "Iterators of different nodes!"); 58 return OpIt == Other.OpIt && MemIt == Other.MemIt; 59 } 60 61 #ifndef NDEBUG 62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); } 63 void DGNode::dump() const { 64 print(dbgs()); 65 dbgs() << "\n"; 66 } 67 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 68 I->dumpOS(OS); 69 if (PrintDeps) { 70 // Print memory preds. 71 static constexpr const unsigned Indent = 4; 72 for (auto *Pred : MemPreds) { 73 OS.indent(Indent) << "<-"; 74 Pred->print(OS, false); 75 OS << "\n"; 76 } 77 } 78 } 79 #endif // NDEBUG 80 81 MemDGNode * 82 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 83 const DependencyGraph &DAG) { 84 Instruction *I = Intvl.top(); 85 Instruction *BeforeI = Intvl.bottom(); 86 // Walk down the chain looking for a mem-dep candidate instruction. 87 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 88 I = I->getNextNode(); 89 if (!DGNode::isMemDepNodeCandidate(I)) 90 return nullptr; 91 return cast<MemDGNode>(DAG.getNode(I)); 92 } 93 94 MemDGNode * 95 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 96 const DependencyGraph &DAG) { 97 Instruction *I = Intvl.bottom(); 98 Instruction *AfterI = Intvl.top(); 99 // Walk up the chain looking for a mem-dep candidate instruction. 100 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 101 I = I->getPrevNode(); 102 if (!DGNode::isMemDepNodeCandidate(I)) 103 return nullptr; 104 return cast<MemDGNode>(DAG.getNode(I)); 105 } 106 107 Interval<MemDGNode> 108 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 109 DependencyGraph &DAG) { 110 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 111 // If we couldn't find a mem node in range TopN - BotN then it's empty. 112 if (TopMemN == nullptr) 113 return {}; 114 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 115 assert(BotMemN != nullptr && "TopMemN should be null too!"); 116 // Now that we have the mem-dep nodes, create and return the range. 117 return Interval<MemDGNode>(TopMemN, BotMemN); 118 } 119 120 DependencyGraph::DependencyType 121 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 122 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 123 if (FromI->mayWriteToMemory()) { 124 if (ToI->mayReadFromMemory()) 125 return DependencyType::ReadAfterWrite; 126 if (ToI->mayWriteToMemory()) 127 return DependencyType::WriteAfterWrite; 128 } else if (FromI->mayReadFromMemory()) { 129 if (ToI->mayWriteToMemory()) 130 return DependencyType::WriteAfterRead; 131 } 132 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 133 return DependencyType::Control; 134 if (ToI->isTerminator()) 135 return DependencyType::Control; 136 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 137 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 138 return DependencyType::Other; 139 return DependencyType::None; 140 } 141 142 static bool isOrdered(Instruction *I) { 143 auto IsOrdered = [](Instruction *I) { 144 if (auto *LI = dyn_cast<LoadInst>(I)) 145 return !LI->isUnordered(); 146 if (auto *SI = dyn_cast<StoreInst>(I)) 147 return !SI->isUnordered(); 148 if (DGNode::isFenceLike(I)) 149 return true; 150 return false; 151 }; 152 bool Is = IsOrdered(I); 153 assert((!Is || DGNode::isMemDepCandidate(I)) && 154 "An ordered instruction must be a MemDepCandidate!"); 155 return Is; 156 } 157 158 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 159 DependencyType DepType) { 160 std::optional<MemoryLocation> DstLocOpt = 161 Utils::memoryLocationGetOrNone(DstI); 162 if (!DstLocOpt) 163 return true; 164 // Check aliasing. 165 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 166 "Expected a mem instr"); 167 // TODO: Check AABudget 168 ModRefInfo SrcModRef = 169 isOrdered(SrcI) 170 ? ModRefInfo::ModRef 171 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 172 switch (DepType) { 173 case DependencyType::ReadAfterWrite: 174 case DependencyType::WriteAfterWrite: 175 return isModSet(SrcModRef); 176 case DependencyType::WriteAfterRead: 177 return isRefSet(SrcModRef); 178 default: 179 llvm_unreachable("Expected only RAW, WAW and WAR!"); 180 } 181 } 182 183 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 184 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 185 switch (RoughDepType) { 186 case DependencyType::ReadAfterWrite: 187 case DependencyType::WriteAfterWrite: 188 case DependencyType::WriteAfterRead: 189 return alias(SrcI, DstI, RoughDepType); 190 case DependencyType::Control: 191 // Adding actual dep edges from PHIs/to terminator would just create too 192 // many edges, which would be bad for compile-time. 193 // So we ignore them in the DAG formation but handle them in the 194 // scheduler, while sorting the ready list. 195 return false; 196 case DependencyType::Other: 197 return true; 198 case DependencyType::None: 199 return false; 200 } 201 llvm_unreachable("Unknown DependencyType enum"); 202 } 203 204 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 205 const Interval<MemDGNode> &SrcScanRange) { 206 assert(isa<MemDGNode>(DstN) && 207 "DstN is the mem dep destination, so it must be mem"); 208 Instruction *DstI = DstN.getInstruction(); 209 // Walk up the instruction chain from ScanRange bottom to top, looking for 210 // memory instrs that may alias. 211 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 212 Instruction *SrcI = SrcN.getInstruction(); 213 if (hasDep(SrcI, DstI)) 214 DstN.addMemPred(&SrcN); 215 } 216 } 217 218 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 219 // Create Nodes only for the new sections of the DAG. 220 DGNode *LastN = getOrCreateNode(NewInterval.top()); 221 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 222 for (Instruction &I : drop_begin(NewInterval)) { 223 auto *N = getOrCreateNode(&I); 224 // Build the Mem node chain. 225 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 226 MemN->setPrevNode(LastMemN); 227 if (LastMemN != nullptr) 228 LastMemN->setNextNode(MemN); 229 LastMemN = MemN; 230 } 231 } 232 // Link new MemDGNode chain with the old one, if any. 233 if (!DAGInterval.empty()) { 234 // TODO: Implement Interval::comesBefore() to replace this check. 235 bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.top()); 236 assert( 237 (NewIsAbove || DAGInterval.bottom()->comesBefore(NewInterval.top())) && 238 "Expected NewInterval below DAGInterval."); 239 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 240 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 241 MemDGNode *LinkTopN = 242 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 243 MemDGNode *LinkBotN = 244 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 245 assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!"); 246 if (LinkTopN != nullptr && LinkBotN != nullptr) { 247 LinkTopN->setNextNode(LinkBotN); 248 LinkBotN->setPrevNode(LinkTopN); 249 } 250 #ifndef NDEBUG 251 // TODO: Remove this once we've done enough testing. 252 // Check that the chain is well formed. 253 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 254 MemDGNode *ChainTopN = 255 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 256 MemDGNode *ChainBotN = 257 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 258 if (ChainTopN != nullptr && ChainBotN != nullptr) { 259 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 260 LastN = N, N = N->getNextNode()) { 261 assert(N == LastN->getNextNode() && "Bad chain!"); 262 assert(N->getPrevNode() == LastN && "Bad chain!"); 263 } 264 } 265 #endif // NDEBUG 266 } 267 } 268 269 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 270 if (Instrs.empty()) 271 return {}; 272 273 Interval<Instruction> InstrsInterval(Instrs); 274 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 275 auto NewInterval = Union.getSingleDiff(DAGInterval); 276 if (NewInterval.empty()) 277 return {}; 278 279 createNewNodes(NewInterval); 280 281 // Create the dependencies. 282 // 283 // 1. DAGInterval empty 2. New is below Old 3. New is above old 284 // ------------------------ ------------------- ------------------- 285 // Scan: DstN: Scan: 286 // +---+ -ScanTopN +---+DstTopN -ScanTopN 287 // | | | |New| | 288 // |Old| | +---+ -ScanBotN 289 // | | | +---+ 290 // DstN: Scan: +---+DstN: | | | 291 // +---+DstTopN -ScanTopN +---+DstTopN | |Old| 292 // |New| | |New| | | | 293 // +---+DstBotN -ScanBotN +---+DstBotN -ScanBotN +---+DstBotN 294 295 // 1. This is a new DAG. 296 if (DAGInterval.empty()) { 297 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 298 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 299 if (!DstRange.empty()) { 300 for (MemDGNode &DstN : drop_begin(DstRange)) { 301 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 302 scanAndAddDeps(DstN, SrcRange); 303 } 304 } 305 } 306 // 2. The new section is below the old section. 307 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 308 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 309 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 310 DAGInterval.getUnionInterval(NewInterval), *this); 311 for (MemDGNode &DstN : DstRange) { 312 auto SrcRange = 313 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 314 scanAndAddDeps(DstN, SrcRange); 315 } 316 } 317 // 3. The new section is above the old section. 318 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 319 auto DstRange = MemDGNodeIntervalBuilder::make( 320 NewInterval.getUnionInterval(DAGInterval), *this); 321 auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this); 322 if (!DstRange.empty()) { 323 for (MemDGNode &DstN : drop_begin(DstRange)) { 324 auto SrcRange = 325 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 326 scanAndAddDeps(DstN, SrcRange); 327 } 328 } 329 } else { 330 llvm_unreachable("We don't expect extending in both directions!"); 331 } 332 333 DAGInterval = Union; 334 return NewInterval; 335 } 336 337 #ifndef NDEBUG 338 void DependencyGraph::print(raw_ostream &OS) const { 339 // InstrToNodeMap is unordered so we need to create an ordered vector. 340 SmallVector<DGNode *> Nodes; 341 Nodes.reserve(InstrToNodeMap.size()); 342 for (const auto &Pair : InstrToNodeMap) 343 Nodes.push_back(Pair.second.get()); 344 // Sort them based on which one comes first in the BB. 345 sort(Nodes, [](DGNode *N1, DGNode *N2) { 346 return N1->getInstruction()->comesBefore(N2->getInstruction()); 347 }); 348 for (auto *N : Nodes) 349 N->print(OS, /*PrintDeps=*/true); 350 } 351 352 void DependencyGraph::dump() const { 353 print(dbgs()); 354 dbgs() << "\n"; 355 } 356 #endif // NDEBUG 357 358 } // namespace llvm::sandboxir 359