1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" 14 15 namespace llvm::sandboxir { 16 17 PredIterator::value_type PredIterator::operator*() { 18 // If it's a DGNode then we dereference the operand iterator. 19 if (!isa<MemDGNode>(N)) { 20 assert(OpIt != OpItE && "Can't dereference end iterator!"); 21 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 22 } 23 // It's a MemDGNode, so we check if we return either the use-def operand, 24 // or a mem predecessor. 25 if (OpIt != OpItE) 26 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 27 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 28 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 29 "Cant' dereference end iterator!"); 30 return *MemIt; 31 } 32 33 PredIterator &PredIterator::operator++() { 34 // If it's a DGNode then we increment the use-def iterator. 35 if (!isa<MemDGNode>(N)) { 36 assert(OpIt != OpItE && "Already at end!"); 37 ++OpIt; 38 // Skip operands that are not instructions. 39 OpIt = skipNonInstr(OpIt, OpItE); 40 return *this; 41 } 42 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 43 // need to first increment that. 44 if (OpIt != OpItE) { 45 ++OpIt; 46 // Skip operands that are not instructions. 47 OpIt = skipNonInstr(OpIt, OpItE); 48 return *this; 49 } 50 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 51 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 52 ++MemIt; 53 return *this; 54 } 55 56 bool PredIterator::operator==(const PredIterator &Other) const { 57 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 58 assert(N == Other.N && "Iterators of different nodes!"); 59 return OpIt == Other.OpIt && MemIt == Other.MemIt; 60 } 61 62 DGNode::~DGNode() { 63 if (SB == nullptr) 64 return; 65 SB->eraseFromBundle(this); 66 } 67 68 #ifndef NDEBUG 69 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { 70 OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n"; 71 } 72 void DGNode::dump() const { print(dbgs()); } 73 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 74 DGNode::print(OS, false); 75 if (PrintDeps) { 76 // Print memory preds. 77 static constexpr const unsigned Indent = 4; 78 for (auto *Pred : MemPreds) 79 OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n"; 80 } 81 } 82 #endif // NDEBUG 83 84 MemDGNode * 85 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 86 const DependencyGraph &DAG) { 87 Instruction *I = Intvl.top(); 88 Instruction *BeforeI = Intvl.bottom(); 89 // Walk down the chain looking for a mem-dep candidate instruction. 90 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 91 I = I->getNextNode(); 92 if (!DGNode::isMemDepNodeCandidate(I)) 93 return nullptr; 94 return cast<MemDGNode>(DAG.getNode(I)); 95 } 96 97 MemDGNode * 98 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 99 const DependencyGraph &DAG) { 100 Instruction *I = Intvl.bottom(); 101 Instruction *AfterI = Intvl.top(); 102 // Walk up the chain looking for a mem-dep candidate instruction. 103 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 104 I = I->getPrevNode(); 105 if (!DGNode::isMemDepNodeCandidate(I)) 106 return nullptr; 107 return cast<MemDGNode>(DAG.getNode(I)); 108 } 109 110 Interval<MemDGNode> 111 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 112 DependencyGraph &DAG) { 113 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 114 // If we couldn't find a mem node in range TopN - BotN then it's empty. 115 if (TopMemN == nullptr) 116 return {}; 117 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 118 assert(BotMemN != nullptr && "TopMemN should be null too!"); 119 // Now that we have the mem-dep nodes, create and return the range. 120 return Interval<MemDGNode>(TopMemN, BotMemN); 121 } 122 123 DependencyGraph::DependencyType 124 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 125 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 126 if (FromI->mayWriteToMemory()) { 127 if (ToI->mayReadFromMemory()) 128 return DependencyType::ReadAfterWrite; 129 if (ToI->mayWriteToMemory()) 130 return DependencyType::WriteAfterWrite; 131 } else if (FromI->mayReadFromMemory()) { 132 if (ToI->mayWriteToMemory()) 133 return DependencyType::WriteAfterRead; 134 } 135 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 136 return DependencyType::Control; 137 if (ToI->isTerminator()) 138 return DependencyType::Control; 139 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 140 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 141 return DependencyType::Other; 142 return DependencyType::None; 143 } 144 145 static bool isOrdered(Instruction *I) { 146 auto IsOrdered = [](Instruction *I) { 147 if (auto *LI = dyn_cast<LoadInst>(I)) 148 return !LI->isUnordered(); 149 if (auto *SI = dyn_cast<StoreInst>(I)) 150 return !SI->isUnordered(); 151 if (DGNode::isFenceLike(I)) 152 return true; 153 return false; 154 }; 155 bool Is = IsOrdered(I); 156 assert((!Is || DGNode::isMemDepCandidate(I)) && 157 "An ordered instruction must be a MemDepCandidate!"); 158 return Is; 159 } 160 161 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 162 DependencyType DepType) { 163 std::optional<MemoryLocation> DstLocOpt = 164 Utils::memoryLocationGetOrNone(DstI); 165 if (!DstLocOpt) 166 return true; 167 // Check aliasing. 168 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 169 "Expected a mem instr"); 170 // TODO: Check AABudget 171 ModRefInfo SrcModRef = 172 isOrdered(SrcI) 173 ? ModRefInfo::ModRef 174 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 175 switch (DepType) { 176 case DependencyType::ReadAfterWrite: 177 case DependencyType::WriteAfterWrite: 178 return isModSet(SrcModRef); 179 case DependencyType::WriteAfterRead: 180 return isRefSet(SrcModRef); 181 default: 182 llvm_unreachable("Expected only RAW, WAW and WAR!"); 183 } 184 } 185 186 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 187 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 188 switch (RoughDepType) { 189 case DependencyType::ReadAfterWrite: 190 case DependencyType::WriteAfterWrite: 191 case DependencyType::WriteAfterRead: 192 return alias(SrcI, DstI, RoughDepType); 193 case DependencyType::Control: 194 // Adding actual dep edges from PHIs/to terminator would just create too 195 // many edges, which would be bad for compile-time. 196 // So we ignore them in the DAG formation but handle them in the 197 // scheduler, while sorting the ready list. 198 return false; 199 case DependencyType::Other: 200 return true; 201 case DependencyType::None: 202 return false; 203 } 204 llvm_unreachable("Unknown DependencyType enum"); 205 } 206 207 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 208 const Interval<MemDGNode> &SrcScanRange) { 209 assert(isa<MemDGNode>(DstN) && 210 "DstN is the mem dep destination, so it must be mem"); 211 Instruction *DstI = DstN.getInstruction(); 212 // Walk up the instruction chain from ScanRange bottom to top, looking for 213 // memory instrs that may alias. 214 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 215 Instruction *SrcI = SrcN.getInstruction(); 216 if (hasDep(SrcI, DstI)) 217 DstN.addMemPred(&SrcN); 218 } 219 } 220 221 void DependencyGraph::setDefUseUnscheduledSuccs( 222 const Interval<Instruction> &NewInterval) { 223 // +---+ 224 // | | Def 225 // | | | 226 // | | v 227 // | | Use 228 // +---+ 229 // Set the intra-interval counters in NewInterval. 230 for (Instruction &I : NewInterval) { 231 for (Value *Op : I.operands()) { 232 auto *OpI = dyn_cast<Instruction>(Op); 233 if (OpI == nullptr) 234 continue; 235 if (!NewInterval.contains(OpI)) 236 continue; 237 auto *OpN = getNode(OpI); 238 if (OpN == nullptr) 239 continue; 240 ++OpN->UnscheduledSuccs; 241 } 242 } 243 244 // Now handle the cross-interval edges. 245 bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval); 246 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 247 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 248 // +---+ 249 // |Top| 250 // | | Def 251 // +---+ | 252 // | | v 253 // |Bot| Use 254 // | | 255 // +---+ 256 // Walk over all instructions in "BotInterval" and update the counter 257 // of operands that are in "TopInterval". 258 for (Instruction &BotI : BotInterval) { 259 auto *BotN = getNode(&BotI); 260 // Skip scheduled nodes. 261 if (BotN->scheduled()) 262 continue; 263 for (Value *Op : BotI.operands()) { 264 auto *OpI = dyn_cast<Instruction>(Op); 265 if (OpI == nullptr) 266 continue; 267 if (!TopInterval.contains(OpI)) 268 continue; 269 auto *OpN = getNode(OpI); 270 if (OpN == nullptr) 271 continue; 272 ++OpN->UnscheduledSuccs; 273 } 274 } 275 } 276 277 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 278 // Create Nodes only for the new sections of the DAG. 279 DGNode *LastN = getOrCreateNode(NewInterval.top()); 280 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 281 for (Instruction &I : drop_begin(NewInterval)) { 282 auto *N = getOrCreateNode(&I); 283 // Build the Mem node chain. 284 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 285 MemN->setPrevNode(LastMemN); 286 if (LastMemN != nullptr) 287 LastMemN->setNextNode(MemN); 288 LastMemN = MemN; 289 } 290 } 291 // Link new MemDGNode chain with the old one, if any. 292 if (!DAGInterval.empty()) { 293 bool NewIsAbove = NewInterval.comesBefore(DAGInterval); 294 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 295 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 296 MemDGNode *LinkTopN = 297 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 298 MemDGNode *LinkBotN = 299 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 300 assert((LinkTopN == nullptr || LinkBotN == nullptr || 301 LinkTopN->comesBefore(LinkBotN)) && 302 "Wrong order!"); 303 if (LinkTopN != nullptr && LinkBotN != nullptr) { 304 LinkTopN->setNextNode(LinkBotN); 305 LinkBotN->setPrevNode(LinkTopN); 306 } 307 #ifndef NDEBUG 308 // TODO: Remove this once we've done enough testing. 309 // Check that the chain is well formed. 310 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 311 MemDGNode *ChainTopN = 312 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 313 MemDGNode *ChainBotN = 314 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 315 if (ChainTopN != nullptr && ChainBotN != nullptr) { 316 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 317 LastN = N, N = N->getNextNode()) { 318 assert(N == LastN->getNextNode() && "Bad chain!"); 319 assert(N->getPrevNode() == LastN && "Bad chain!"); 320 } 321 } 322 #endif // NDEBUG 323 } 324 325 setDefUseUnscheduledSuccs(NewInterval); 326 } 327 328 MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, 329 bool IncludingN) const { 330 auto *I = N->getInstruction(); 331 for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr; 332 PrevI = PrevI->getPrevNode()) { 333 auto *PrevN = getNodeOrNull(PrevI); 334 if (PrevN == nullptr) 335 return nullptr; 336 if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN)) 337 return PrevMemN; 338 } 339 return nullptr; 340 } 341 342 MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, 343 bool IncludingN) const { 344 auto *I = N->getInstruction(); 345 for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr; 346 NextI = NextI->getNextNode()) { 347 auto *NextN = getNodeOrNull(NextI); 348 if (NextN == nullptr) 349 return nullptr; 350 if (auto *NextMemN = dyn_cast<MemDGNode>(NextN)) 351 return NextMemN; 352 } 353 return nullptr; 354 } 355 356 void DependencyGraph::notifyCreateInstr(Instruction *I) { 357 auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I)); 358 // TODO: Update the dependencies for the new node. 359 360 // Update the MemDGNode chain if this is a memory node. 361 if (MemN != nullptr) { 362 if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) { 363 PrevMemN->NextMemN = MemN; 364 MemN->PrevMemN = PrevMemN; 365 } 366 if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) { 367 NextMemN->PrevMemN = MemN; 368 MemN->NextMemN = NextMemN; 369 } 370 } 371 } 372 373 void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) { 374 // Early return if `I` doesn't actually move. 375 BasicBlock *BB = To.getNodeParent(); 376 if (To != BB->end() && &*To == I->getNextNode()) 377 return; 378 379 // Maintain the DAGInterval. 380 DAGInterval.notifyMoveInstr(I, To); 381 382 // TODO: Perhaps check if this is legal by checking the dependencies? 383 384 // Update the MemDGNode chain to reflect the instr movement if necessary. 385 DGNode *N = getNodeOrNull(I); 386 if (N == nullptr) 387 return; 388 MemDGNode *MemN = dyn_cast<MemDGNode>(N); 389 if (MemN == nullptr) 390 return; 391 // First detach it from the existing chain. 392 MemN->detachFromChain(); 393 // Now insert it back into the chain at the new location. 394 if (To != BB->end()) { 395 DGNode *ToN = getNodeOrNull(&*To); 396 if (ToN != nullptr) { 397 MemDGNode *PrevMemN = getMemDGNodeBefore(ToN, /*IncludingN=*/false); 398 MemDGNode *NextMemN = getMemDGNodeAfter(ToN, /*IncludingN=*/true); 399 MemN->PrevMemN = PrevMemN; 400 if (PrevMemN != nullptr) 401 PrevMemN->NextMemN = MemN; 402 MemN->NextMemN = NextMemN; 403 if (NextMemN != nullptr) 404 NextMemN->PrevMemN = MemN; 405 } 406 } else { 407 // MemN becomes the last instruction in the BB. 408 auto *TermN = getNodeOrNull(BB->getTerminator()); 409 if (TermN != nullptr) { 410 MemDGNode *PrevMemN = getMemDGNodeBefore(TermN, /*IncludingN=*/false); 411 PrevMemN->NextMemN = MemN; 412 MemN->PrevMemN = PrevMemN; 413 } else { 414 // The terminator is outside the DAG interval so do nothing. 415 } 416 } 417 } 418 419 void DependencyGraph::notifyEraseInstr(Instruction *I) { 420 // Update the MemDGNode chain if this is a memory node. 421 if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) { 422 auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false); 423 auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false); 424 if (PrevMemN != nullptr) 425 PrevMemN->NextMemN = NextMemN; 426 if (NextMemN != nullptr) 427 NextMemN->PrevMemN = PrevMemN; 428 } 429 430 InstrToNodeMap.erase(I); 431 432 // TODO: Update the dependencies. 433 } 434 435 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 436 if (Instrs.empty()) 437 return {}; 438 439 Interval<Instruction> InstrsInterval(Instrs); 440 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 441 auto NewInterval = Union.getSingleDiff(DAGInterval); 442 if (NewInterval.empty()) 443 return {}; 444 445 createNewNodes(NewInterval); 446 447 // Create the dependencies. 448 // 449 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. 450 // +---+ - - 451 // | | SrcN | | 452 // | | | | SrcRange | 453 // |New| v | | DstRange 454 // | | DstN - | 455 // | | | 456 // +---+ - 457 // We are scanning for deps with destination in NewInterval and sources in 458 // NewInterval until DstN, for each DstN. 459 auto FullScan = [this](const Interval<Instruction> Intvl) { 460 auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this); 461 if (!DstRange.empty()) { 462 for (MemDGNode &DstN : drop_begin(DstRange)) { 463 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 464 scanAndAddDeps(DstN, SrcRange); 465 } 466 } 467 }; 468 if (DAGInterval.empty()) { 469 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 470 FullScan(NewInterval); 471 } 472 // 2. The new section is below the old section. 473 // +---+ - 474 // | | | 475 // |Old| SrcN | 476 // | | | | 477 // +---+ | | SrcRange 478 // +---+ | | - 479 // | | | | | 480 // |New| v | | DstRange 481 // | | DstN - | 482 // | | | 483 // +---+ - 484 // We are scanning for deps with destination in NewInterval because the deps 485 // in DAGInterval have already been computed. We consider sources in the whole 486 // range including both NewInterval and DAGInterval until DstN, for each DstN. 487 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 488 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 489 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 490 DAGInterval.getUnionInterval(NewInterval), *this); 491 for (MemDGNode &DstN : DstRange) { 492 auto SrcRange = 493 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 494 scanAndAddDeps(DstN, SrcRange); 495 } 496 } 497 // 3. The new section is above the old section. 498 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 499 // +---+ - - 500 // | | SrcN | | 501 // |New| | | SrcRange | DstRange 502 // | | v | | 503 // | | DstN - | 504 // | | | 505 // +---+ - 506 // +---+ 507 // |Old| 508 // | | 509 // +---+ 510 // When scanning for deps with destination in NewInterval we need to fully 511 // scan the interval. This is the same as the scanning for a new DAG. 512 FullScan(NewInterval); 513 514 // +---+ - 515 // | | | 516 // |New| SrcN | SrcRange 517 // | | | | 518 // | | | | 519 // | | | | 520 // +---+ | - 521 // +---+ | - 522 // |Old| v | DstRange 523 // | | DstN | 524 // +---+ - 525 // When scanning for deps with destination in DAGInterval we need to 526 // consider sources from the NewInterval only, because all intra-DAGInterval 527 // dependencies have already been created. 528 auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); 529 auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 530 for (MemDGNode &DstN : DstRangeOld) 531 scanAndAddDeps(DstN, SrcRange); 532 } else { 533 llvm_unreachable("We don't expect extending in both directions!"); 534 } 535 536 DAGInterval = Union; 537 return NewInterval; 538 } 539 540 #ifndef NDEBUG 541 void DependencyGraph::print(raw_ostream &OS) const { 542 // InstrToNodeMap is unordered so we need to create an ordered vector. 543 SmallVector<DGNode *> Nodes; 544 Nodes.reserve(InstrToNodeMap.size()); 545 for (const auto &Pair : InstrToNodeMap) 546 Nodes.push_back(Pair.second.get()); 547 // Sort them based on which one comes first in the BB. 548 sort(Nodes, [](DGNode *N1, DGNode *N2) { 549 return N1->getInstruction()->comesBefore(N2->getInstruction()); 550 }); 551 for (auto *N : Nodes) 552 N->print(OS, /*PrintDeps=*/true); 553 } 554 555 void DependencyGraph::dump() const { 556 print(dbgs()); 557 dbgs() << "\n"; 558 } 559 #endif // NDEBUG 560 561 } // namespace llvm::sandboxir 562