1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 14 namespace llvm::sandboxir { 15 16 PredIterator::value_type PredIterator::operator*() { 17 // If it's a DGNode then we dereference the operand iterator. 18 if (!isa<MemDGNode>(N)) { 19 assert(OpIt != OpItE && "Can't dereference end iterator!"); 20 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 21 } 22 // It's a MemDGNode, so we check if we return either the use-def operand, 23 // or a mem predecessor. 24 if (OpIt != OpItE) 25 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 26 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 27 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 28 "Cant' dereference end iterator!"); 29 return *MemIt; 30 } 31 32 PredIterator &PredIterator::operator++() { 33 // If it's a DGNode then we increment the use-def iterator. 34 if (!isa<MemDGNode>(N)) { 35 assert(OpIt != OpItE && "Already at end!"); 36 ++OpIt; 37 // Skip operands that are not instructions. 38 OpIt = skipNonInstr(OpIt, OpItE); 39 return *this; 40 } 41 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 42 // need to first increment that. 43 if (OpIt != OpItE) { 44 ++OpIt; 45 // Skip operands that are not instructions. 46 OpIt = skipNonInstr(OpIt, OpItE); 47 return *this; 48 } 49 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 50 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 51 ++MemIt; 52 return *this; 53 } 54 55 bool PredIterator::operator==(const PredIterator &Other) const { 56 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 57 assert(N == Other.N && "Iterators of different nodes!"); 58 return OpIt == Other.OpIt && MemIt == Other.MemIt; 59 } 60 61 #ifndef NDEBUG 62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { 63 OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n"; 64 } 65 void DGNode::dump() const { print(dbgs()); } 66 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 67 DGNode::print(OS, false); 68 if (PrintDeps) { 69 // Print memory preds. 70 static constexpr const unsigned Indent = 4; 71 for (auto *Pred : MemPreds) 72 OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n"; 73 } 74 } 75 #endif // NDEBUG 76 77 MemDGNode * 78 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 79 const DependencyGraph &DAG) { 80 Instruction *I = Intvl.top(); 81 Instruction *BeforeI = Intvl.bottom(); 82 // Walk down the chain looking for a mem-dep candidate instruction. 83 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 84 I = I->getNextNode(); 85 if (!DGNode::isMemDepNodeCandidate(I)) 86 return nullptr; 87 return cast<MemDGNode>(DAG.getNode(I)); 88 } 89 90 MemDGNode * 91 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 92 const DependencyGraph &DAG) { 93 Instruction *I = Intvl.bottom(); 94 Instruction *AfterI = Intvl.top(); 95 // Walk up the chain looking for a mem-dep candidate instruction. 96 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 97 I = I->getPrevNode(); 98 if (!DGNode::isMemDepNodeCandidate(I)) 99 return nullptr; 100 return cast<MemDGNode>(DAG.getNode(I)); 101 } 102 103 Interval<MemDGNode> 104 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 105 DependencyGraph &DAG) { 106 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 107 // If we couldn't find a mem node in range TopN - BotN then it's empty. 108 if (TopMemN == nullptr) 109 return {}; 110 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 111 assert(BotMemN != nullptr && "TopMemN should be null too!"); 112 // Now that we have the mem-dep nodes, create and return the range. 113 return Interval<MemDGNode>(TopMemN, BotMemN); 114 } 115 116 DependencyGraph::DependencyType 117 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 118 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 119 if (FromI->mayWriteToMemory()) { 120 if (ToI->mayReadFromMemory()) 121 return DependencyType::ReadAfterWrite; 122 if (ToI->mayWriteToMemory()) 123 return DependencyType::WriteAfterWrite; 124 } else if (FromI->mayReadFromMemory()) { 125 if (ToI->mayWriteToMemory()) 126 return DependencyType::WriteAfterRead; 127 } 128 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 129 return DependencyType::Control; 130 if (ToI->isTerminator()) 131 return DependencyType::Control; 132 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 133 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 134 return DependencyType::Other; 135 return DependencyType::None; 136 } 137 138 static bool isOrdered(Instruction *I) { 139 auto IsOrdered = [](Instruction *I) { 140 if (auto *LI = dyn_cast<LoadInst>(I)) 141 return !LI->isUnordered(); 142 if (auto *SI = dyn_cast<StoreInst>(I)) 143 return !SI->isUnordered(); 144 if (DGNode::isFenceLike(I)) 145 return true; 146 return false; 147 }; 148 bool Is = IsOrdered(I); 149 assert((!Is || DGNode::isMemDepCandidate(I)) && 150 "An ordered instruction must be a MemDepCandidate!"); 151 return Is; 152 } 153 154 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 155 DependencyType DepType) { 156 std::optional<MemoryLocation> DstLocOpt = 157 Utils::memoryLocationGetOrNone(DstI); 158 if (!DstLocOpt) 159 return true; 160 // Check aliasing. 161 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 162 "Expected a mem instr"); 163 // TODO: Check AABudget 164 ModRefInfo SrcModRef = 165 isOrdered(SrcI) 166 ? ModRefInfo::ModRef 167 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 168 switch (DepType) { 169 case DependencyType::ReadAfterWrite: 170 case DependencyType::WriteAfterWrite: 171 return isModSet(SrcModRef); 172 case DependencyType::WriteAfterRead: 173 return isRefSet(SrcModRef); 174 default: 175 llvm_unreachable("Expected only RAW, WAW and WAR!"); 176 } 177 } 178 179 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 180 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 181 switch (RoughDepType) { 182 case DependencyType::ReadAfterWrite: 183 case DependencyType::WriteAfterWrite: 184 case DependencyType::WriteAfterRead: 185 return alias(SrcI, DstI, RoughDepType); 186 case DependencyType::Control: 187 // Adding actual dep edges from PHIs/to terminator would just create too 188 // many edges, which would be bad for compile-time. 189 // So we ignore them in the DAG formation but handle them in the 190 // scheduler, while sorting the ready list. 191 return false; 192 case DependencyType::Other: 193 return true; 194 case DependencyType::None: 195 return false; 196 } 197 llvm_unreachable("Unknown DependencyType enum"); 198 } 199 200 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 201 const Interval<MemDGNode> &SrcScanRange) { 202 assert(isa<MemDGNode>(DstN) && 203 "DstN is the mem dep destination, so it must be mem"); 204 Instruction *DstI = DstN.getInstruction(); 205 // Walk up the instruction chain from ScanRange bottom to top, looking for 206 // memory instrs that may alias. 207 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 208 Instruction *SrcI = SrcN.getInstruction(); 209 if (hasDep(SrcI, DstI)) 210 DstN.addMemPred(&SrcN); 211 } 212 } 213 214 void DependencyGraph::setDefUseUnscheduledSuccs( 215 const Interval<Instruction> &NewInterval) { 216 // +---+ 217 // | | Def 218 // | | | 219 // | | v 220 // | | Use 221 // +---+ 222 // Set the intra-interval counters in NewInterval. 223 for (Instruction &I : NewInterval) { 224 for (Value *Op : I.operands()) { 225 auto *OpI = dyn_cast<Instruction>(Op); 226 if (OpI == nullptr) 227 continue; 228 if (!NewInterval.contains(OpI)) 229 continue; 230 auto *OpN = getNode(OpI); 231 if (OpN == nullptr) 232 continue; 233 ++OpN->UnscheduledSuccs; 234 } 235 } 236 237 // Now handle the cross-interval edges. 238 bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval); 239 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 240 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 241 // +---+ 242 // |Top| 243 // | | Def 244 // +---+ | 245 // | | v 246 // |Bot| Use 247 // | | 248 // +---+ 249 // Walk over all instructions in "BotInterval" and update the counter 250 // of operands that are in "TopInterval". 251 for (Instruction &BotI : BotInterval) { 252 auto *BotN = getNode(&BotI); 253 // Skip scheduled nodes. 254 if (BotN->scheduled()) 255 continue; 256 for (Value *Op : BotI.operands()) { 257 auto *OpI = dyn_cast<Instruction>(Op); 258 if (OpI == nullptr) 259 continue; 260 if (!TopInterval.contains(OpI)) 261 continue; 262 auto *OpN = getNode(OpI); 263 if (OpN == nullptr) 264 continue; 265 ++OpN->UnscheduledSuccs; 266 } 267 } 268 } 269 270 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 271 // Create Nodes only for the new sections of the DAG. 272 DGNode *LastN = getOrCreateNode(NewInterval.top()); 273 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 274 for (Instruction &I : drop_begin(NewInterval)) { 275 auto *N = getOrCreateNode(&I); 276 // Build the Mem node chain. 277 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 278 MemN->setPrevNode(LastMemN); 279 if (LastMemN != nullptr) 280 LastMemN->setNextNode(MemN); 281 LastMemN = MemN; 282 } 283 } 284 // Link new MemDGNode chain with the old one, if any. 285 if (!DAGInterval.empty()) { 286 bool NewIsAbove = NewInterval.comesBefore(DAGInterval); 287 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 288 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 289 MemDGNode *LinkTopN = 290 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 291 MemDGNode *LinkBotN = 292 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 293 assert((LinkTopN == nullptr || LinkBotN == nullptr || 294 LinkTopN->comesBefore(LinkBotN)) && 295 "Wrong order!"); 296 if (LinkTopN != nullptr && LinkBotN != nullptr) { 297 LinkTopN->setNextNode(LinkBotN); 298 LinkBotN->setPrevNode(LinkTopN); 299 } 300 #ifndef NDEBUG 301 // TODO: Remove this once we've done enough testing. 302 // Check that the chain is well formed. 303 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 304 MemDGNode *ChainTopN = 305 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 306 MemDGNode *ChainBotN = 307 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 308 if (ChainTopN != nullptr && ChainBotN != nullptr) { 309 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 310 LastN = N, N = N->getNextNode()) { 311 assert(N == LastN->getNextNode() && "Bad chain!"); 312 assert(N->getPrevNode() == LastN && "Bad chain!"); 313 } 314 } 315 #endif // NDEBUG 316 } 317 318 setDefUseUnscheduledSuccs(NewInterval); 319 } 320 321 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 322 if (Instrs.empty()) 323 return {}; 324 325 Interval<Instruction> InstrsInterval(Instrs); 326 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 327 auto NewInterval = Union.getSingleDiff(DAGInterval); 328 if (NewInterval.empty()) 329 return {}; 330 331 createNewNodes(NewInterval); 332 333 // Create the dependencies. 334 // 335 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. 336 // +---+ - - 337 // | | SrcN | | 338 // | | | | SrcRange | 339 // |New| v | | DstRange 340 // | | DstN - | 341 // | | | 342 // +---+ - 343 // We are scanning for deps with destination in NewInterval and sources in 344 // NewInterval until DstN, for each DstN. 345 auto FullScan = [this](const Interval<Instruction> Intvl) { 346 auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this); 347 if (!DstRange.empty()) { 348 for (MemDGNode &DstN : drop_begin(DstRange)) { 349 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 350 scanAndAddDeps(DstN, SrcRange); 351 } 352 } 353 }; 354 if (DAGInterval.empty()) { 355 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 356 FullScan(NewInterval); 357 } 358 // 2. The new section is below the old section. 359 // +---+ - 360 // | | | 361 // |Old| SrcN | 362 // | | | | 363 // +---+ | | SrcRange 364 // +---+ | | - 365 // | | | | | 366 // |New| v | | DstRange 367 // | | DstN - | 368 // | | | 369 // +---+ - 370 // We are scanning for deps with destination in NewInterval because the deps 371 // in DAGInterval have already been computed. We consider sources in the whole 372 // range including both NewInterval and DAGInterval until DstN, for each DstN. 373 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 374 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 375 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 376 DAGInterval.getUnionInterval(NewInterval), *this); 377 for (MemDGNode &DstN : DstRange) { 378 auto SrcRange = 379 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 380 scanAndAddDeps(DstN, SrcRange); 381 } 382 } 383 // 3. The new section is above the old section. 384 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 385 // +---+ - - 386 // | | SrcN | | 387 // |New| | | SrcRange | DstRange 388 // | | v | | 389 // | | DstN - | 390 // | | | 391 // +---+ - 392 // +---+ 393 // |Old| 394 // | | 395 // +---+ 396 // When scanning for deps with destination in NewInterval we need to fully 397 // scan the interval. This is the same as the scanning for a new DAG. 398 FullScan(NewInterval); 399 400 // +---+ - 401 // | | | 402 // |New| SrcN | SrcRange 403 // | | | | 404 // | | | | 405 // | | | | 406 // +---+ | - 407 // +---+ | - 408 // |Old| v | DstRange 409 // | | DstN | 410 // +---+ - 411 // When scanning for deps with destination in DAGInterval we need to 412 // consider sources from the NewInterval only, because all intra-DAGInterval 413 // dependencies have already been created. 414 auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); 415 auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 416 for (MemDGNode &DstN : DstRangeOld) 417 scanAndAddDeps(DstN, SrcRange); 418 } else { 419 llvm_unreachable("We don't expect extending in both directions!"); 420 } 421 422 DAGInterval = Union; 423 return NewInterval; 424 } 425 426 #ifndef NDEBUG 427 void DependencyGraph::print(raw_ostream &OS) const { 428 // InstrToNodeMap is unordered so we need to create an ordered vector. 429 SmallVector<DGNode *> Nodes; 430 Nodes.reserve(InstrToNodeMap.size()); 431 for (const auto &Pair : InstrToNodeMap) 432 Nodes.push_back(Pair.second.get()); 433 // Sort them based on which one comes first in the BB. 434 sort(Nodes, [](DGNode *N1, DGNode *N2) { 435 return N1->getInstruction()->comesBefore(N2->getInstruction()); 436 }); 437 for (auto *N : Nodes) 438 N->print(OS, /*PrintDeps=*/true); 439 } 440 441 void DependencyGraph::dump() const { 442 print(dbgs()); 443 dbgs() << "\n"; 444 } 445 #endif // NDEBUG 446 447 } // namespace llvm::sandboxir 448