1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" 14 15 namespace llvm::sandboxir { 16 17 PredIterator::value_type PredIterator::operator*() { 18 // If it's a DGNode then we dereference the operand iterator. 19 if (!isa<MemDGNode>(N)) { 20 assert(OpIt != OpItE && "Can't dereference end iterator!"); 21 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 22 } 23 // It's a MemDGNode, so we check if we return either the use-def operand, 24 // or a mem predecessor. 25 if (OpIt != OpItE) 26 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 27 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 28 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 29 "Cant' dereference end iterator!"); 30 return *MemIt; 31 } 32 33 PredIterator &PredIterator::operator++() { 34 // If it's a DGNode then we increment the use-def iterator. 35 if (!isa<MemDGNode>(N)) { 36 assert(OpIt != OpItE && "Already at end!"); 37 ++OpIt; 38 // Skip operands that are not instructions. 39 OpIt = skipNonInstr(OpIt, OpItE); 40 return *this; 41 } 42 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 43 // need to first increment that. 44 if (OpIt != OpItE) { 45 ++OpIt; 46 // Skip operands that are not instructions. 47 OpIt = skipNonInstr(OpIt, OpItE); 48 return *this; 49 } 50 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 51 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 52 ++MemIt; 53 return *this; 54 } 55 56 bool PredIterator::operator==(const PredIterator &Other) const { 57 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 58 assert(N == Other.N && "Iterators of different nodes!"); 59 return OpIt == Other.OpIt && MemIt == Other.MemIt; 60 } 61 62 DGNode::~DGNode() { 63 if (SB == nullptr) 64 return; 65 SB->eraseFromBundle(this); 66 } 67 68 #ifndef NDEBUG 69 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { 70 OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n"; 71 } 72 void DGNode::dump() const { print(dbgs()); } 73 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 74 DGNode::print(OS, false); 75 if (PrintDeps) { 76 // Print memory preds. 77 static constexpr const unsigned Indent = 4; 78 for (auto *Pred : MemPreds) 79 OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n"; 80 } 81 } 82 #endif // NDEBUG 83 84 MemDGNode * 85 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 86 const DependencyGraph &DAG) { 87 Instruction *I = Intvl.top(); 88 Instruction *BeforeI = Intvl.bottom(); 89 // Walk down the chain looking for a mem-dep candidate instruction. 90 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 91 I = I->getNextNode(); 92 if (!DGNode::isMemDepNodeCandidate(I)) 93 return nullptr; 94 return cast<MemDGNode>(DAG.getNode(I)); 95 } 96 97 MemDGNode * 98 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 99 const DependencyGraph &DAG) { 100 Instruction *I = Intvl.bottom(); 101 Instruction *AfterI = Intvl.top(); 102 // Walk up the chain looking for a mem-dep candidate instruction. 103 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 104 I = I->getPrevNode(); 105 if (!DGNode::isMemDepNodeCandidate(I)) 106 return nullptr; 107 return cast<MemDGNode>(DAG.getNode(I)); 108 } 109 110 Interval<MemDGNode> 111 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 112 DependencyGraph &DAG) { 113 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 114 // If we couldn't find a mem node in range TopN - BotN then it's empty. 115 if (TopMemN == nullptr) 116 return {}; 117 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 118 assert(BotMemN != nullptr && "TopMemN should be null too!"); 119 // Now that we have the mem-dep nodes, create and return the range. 120 return Interval<MemDGNode>(TopMemN, BotMemN); 121 } 122 123 DependencyGraph::DependencyType 124 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 125 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 126 if (FromI->mayWriteToMemory()) { 127 if (ToI->mayReadFromMemory()) 128 return DependencyType::ReadAfterWrite; 129 if (ToI->mayWriteToMemory()) 130 return DependencyType::WriteAfterWrite; 131 } else if (FromI->mayReadFromMemory()) { 132 if (ToI->mayWriteToMemory()) 133 return DependencyType::WriteAfterRead; 134 } 135 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 136 return DependencyType::Control; 137 if (ToI->isTerminator()) 138 return DependencyType::Control; 139 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 140 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 141 return DependencyType::Other; 142 return DependencyType::None; 143 } 144 145 static bool isOrdered(Instruction *I) { 146 auto IsOrdered = [](Instruction *I) { 147 if (auto *LI = dyn_cast<LoadInst>(I)) 148 return !LI->isUnordered(); 149 if (auto *SI = dyn_cast<StoreInst>(I)) 150 return !SI->isUnordered(); 151 if (DGNode::isFenceLike(I)) 152 return true; 153 return false; 154 }; 155 bool Is = IsOrdered(I); 156 assert((!Is || DGNode::isMemDepCandidate(I)) && 157 "An ordered instruction must be a MemDepCandidate!"); 158 return Is; 159 } 160 161 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 162 DependencyType DepType) { 163 std::optional<MemoryLocation> DstLocOpt = 164 Utils::memoryLocationGetOrNone(DstI); 165 if (!DstLocOpt) 166 return true; 167 // Check aliasing. 168 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 169 "Expected a mem instr"); 170 // TODO: Check AABudget 171 ModRefInfo SrcModRef = 172 isOrdered(SrcI) 173 ? ModRefInfo::ModRef 174 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 175 switch (DepType) { 176 case DependencyType::ReadAfterWrite: 177 case DependencyType::WriteAfterWrite: 178 return isModSet(SrcModRef); 179 case DependencyType::WriteAfterRead: 180 return isRefSet(SrcModRef); 181 default: 182 llvm_unreachable("Expected only RAW, WAW and WAR!"); 183 } 184 } 185 186 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 187 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 188 switch (RoughDepType) { 189 case DependencyType::ReadAfterWrite: 190 case DependencyType::WriteAfterWrite: 191 case DependencyType::WriteAfterRead: 192 return alias(SrcI, DstI, RoughDepType); 193 case DependencyType::Control: 194 // Adding actual dep edges from PHIs/to terminator would just create too 195 // many edges, which would be bad for compile-time. 196 // So we ignore them in the DAG formation but handle them in the 197 // scheduler, while sorting the ready list. 198 return false; 199 case DependencyType::Other: 200 return true; 201 case DependencyType::None: 202 return false; 203 } 204 llvm_unreachable("Unknown DependencyType enum"); 205 } 206 207 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 208 const Interval<MemDGNode> &SrcScanRange) { 209 assert(isa<MemDGNode>(DstN) && 210 "DstN is the mem dep destination, so it must be mem"); 211 Instruction *DstI = DstN.getInstruction(); 212 // Walk up the instruction chain from ScanRange bottom to top, looking for 213 // memory instrs that may alias. 214 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 215 Instruction *SrcI = SrcN.getInstruction(); 216 if (hasDep(SrcI, DstI)) 217 DstN.addMemPred(&SrcN); 218 } 219 } 220 221 void DependencyGraph::setDefUseUnscheduledSuccs( 222 const Interval<Instruction> &NewInterval) { 223 // +---+ 224 // | | Def 225 // | | | 226 // | | v 227 // | | Use 228 // +---+ 229 // Set the intra-interval counters in NewInterval. 230 for (Instruction &I : NewInterval) { 231 for (Value *Op : I.operands()) { 232 auto *OpI = dyn_cast<Instruction>(Op); 233 if (OpI == nullptr) 234 continue; 235 // TODO: For now don't cross BBs. 236 if (OpI->getParent() != I.getParent()) 237 continue; 238 if (!NewInterval.contains(OpI)) 239 continue; 240 auto *OpN = getNode(OpI); 241 if (OpN == nullptr) 242 continue; 243 ++OpN->UnscheduledSuccs; 244 } 245 } 246 247 // Now handle the cross-interval edges. 248 bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval); 249 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 250 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 251 // +---+ 252 // |Top| 253 // | | Def 254 // +---+ | 255 // | | v 256 // |Bot| Use 257 // | | 258 // +---+ 259 // Walk over all instructions in "BotInterval" and update the counter 260 // of operands that are in "TopInterval". 261 for (Instruction &BotI : BotInterval) { 262 auto *BotN = getNode(&BotI); 263 // Skip scheduled nodes. 264 if (BotN->scheduled()) 265 continue; 266 for (Value *Op : BotI.operands()) { 267 auto *OpI = dyn_cast<Instruction>(Op); 268 if (OpI == nullptr) 269 continue; 270 if (!TopInterval.contains(OpI)) 271 continue; 272 auto *OpN = getNode(OpI); 273 if (OpN == nullptr) 274 continue; 275 ++OpN->UnscheduledSuccs; 276 } 277 } 278 } 279 280 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 281 // Create Nodes only for the new sections of the DAG. 282 DGNode *LastN = getOrCreateNode(NewInterval.top()); 283 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 284 for (Instruction &I : drop_begin(NewInterval)) { 285 auto *N = getOrCreateNode(&I); 286 // Build the Mem node chain. 287 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 288 MemN->setPrevNode(LastMemN); 289 LastMemN = MemN; 290 } 291 } 292 // Link new MemDGNode chain with the old one, if any. 293 if (!DAGInterval.empty()) { 294 bool NewIsAbove = NewInterval.comesBefore(DAGInterval); 295 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 296 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 297 MemDGNode *LinkTopN = 298 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 299 MemDGNode *LinkBotN = 300 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 301 assert((LinkTopN == nullptr || LinkBotN == nullptr || 302 LinkTopN->comesBefore(LinkBotN)) && 303 "Wrong order!"); 304 if (LinkTopN != nullptr && LinkBotN != nullptr) { 305 LinkTopN->setNextNode(LinkBotN); 306 } 307 #ifndef NDEBUG 308 // TODO: Remove this once we've done enough testing. 309 // Check that the chain is well formed. 310 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 311 MemDGNode *ChainTopN = 312 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 313 MemDGNode *ChainBotN = 314 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 315 if (ChainTopN != nullptr && ChainBotN != nullptr) { 316 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 317 LastN = N, N = N->getNextNode()) { 318 assert(N == LastN->getNextNode() && "Bad chain!"); 319 assert(N->getPrevNode() == LastN && "Bad chain!"); 320 } 321 } 322 #endif // NDEBUG 323 } 324 325 setDefUseUnscheduledSuccs(NewInterval); 326 } 327 328 MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, 329 bool IncludingN) const { 330 auto *I = N->getInstruction(); 331 for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr; 332 PrevI = PrevI->getPrevNode()) { 333 auto *PrevN = getNodeOrNull(PrevI); 334 if (PrevN == nullptr) 335 return nullptr; 336 if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN)) 337 return PrevMemN; 338 } 339 return nullptr; 340 } 341 342 MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, 343 bool IncludingN) const { 344 auto *I = N->getInstruction(); 345 for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr; 346 NextI = NextI->getNextNode()) { 347 auto *NextN = getNodeOrNull(NextI); 348 if (NextN == nullptr) 349 return nullptr; 350 if (auto *NextMemN = dyn_cast<MemDGNode>(NextN)) 351 return NextMemN; 352 } 353 return nullptr; 354 } 355 356 void DependencyGraph::notifyCreateInstr(Instruction *I) { 357 auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I)); 358 // TODO: Update the dependencies for the new node. 359 360 // Update the MemDGNode chain if this is a memory node. 361 if (MemN != nullptr) { 362 if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) { 363 PrevMemN->NextMemN = MemN; 364 MemN->PrevMemN = PrevMemN; 365 } 366 if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) { 367 NextMemN->PrevMemN = MemN; 368 MemN->NextMemN = NextMemN; 369 } 370 } 371 } 372 373 void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) { 374 // NOTE: This function runs before `I` moves to its new destination. 375 BasicBlock *BB = To.getNodeParent(); 376 assert(!(To != BB->end() && &*To == I->getNextNode()) && 377 !(To == BB->end() && std::next(I->getIterator()) == BB->end()) && 378 "Should not have been called if destination is same as origin."); 379 380 // Maintain the DAGInterval. 381 DAGInterval.notifyMoveInstr(I, To); 382 383 // TODO: Perhaps check if this is legal by checking the dependencies? 384 385 // Update the MemDGNode chain to reflect the instr movement if necessary. 386 DGNode *N = getNodeOrNull(I); 387 if (N == nullptr) 388 return; 389 MemDGNode *MemN = dyn_cast<MemDGNode>(N); 390 if (MemN == nullptr) 391 return; 392 // First detach it from the existing chain. 393 MemN->detachFromChain(); 394 // Now insert it back into the chain at the new location. 395 if (To != BB->end()) { 396 DGNode *ToN = getNodeOrNull(&*To); 397 if (ToN != nullptr) { 398 MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false)); 399 MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true)); 400 } 401 } else { 402 // MemN becomes the last instruction in the BB. 403 auto *TermN = getNodeOrNull(BB->getTerminator()); 404 if (TermN != nullptr) { 405 MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false)); 406 } else { 407 // The terminator is outside the DAG interval so do nothing. 408 } 409 } 410 } 411 412 void DependencyGraph::notifyEraseInstr(Instruction *I) { 413 // Update the MemDGNode chain if this is a memory node. 414 if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) { 415 auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false); 416 auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false); 417 if (PrevMemN != nullptr) 418 PrevMemN->NextMemN = NextMemN; 419 if (NextMemN != nullptr) 420 NextMemN->PrevMemN = PrevMemN; 421 } 422 423 InstrToNodeMap.erase(I); 424 425 // TODO: Update the dependencies. 426 } 427 428 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 429 if (Instrs.empty()) 430 return {}; 431 432 Interval<Instruction> InstrsInterval(Instrs); 433 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 434 auto NewInterval = Union.getSingleDiff(DAGInterval); 435 if (NewInterval.empty()) 436 return {}; 437 438 createNewNodes(NewInterval); 439 440 // Create the dependencies. 441 // 442 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. 443 // +---+ - - 444 // | | SrcN | | 445 // | | | | SrcRange | 446 // |New| v | | DstRange 447 // | | DstN - | 448 // | | | 449 // +---+ - 450 // We are scanning for deps with destination in NewInterval and sources in 451 // NewInterval until DstN, for each DstN. 452 auto FullScan = [this](const Interval<Instruction> Intvl) { 453 auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this); 454 if (!DstRange.empty()) { 455 for (MemDGNode &DstN : drop_begin(DstRange)) { 456 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 457 scanAndAddDeps(DstN, SrcRange); 458 } 459 } 460 }; 461 if (DAGInterval.empty()) { 462 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 463 FullScan(NewInterval); 464 } 465 // 2. The new section is below the old section. 466 // +---+ - 467 // | | | 468 // |Old| SrcN | 469 // | | | | 470 // +---+ | | SrcRange 471 // +---+ | | - 472 // | | | | | 473 // |New| v | | DstRange 474 // | | DstN - | 475 // | | | 476 // +---+ - 477 // We are scanning for deps with destination in NewInterval because the deps 478 // in DAGInterval have already been computed. We consider sources in the whole 479 // range including both NewInterval and DAGInterval until DstN, for each DstN. 480 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 481 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 482 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 483 DAGInterval.getUnionInterval(NewInterval), *this); 484 for (MemDGNode &DstN : DstRange) { 485 auto SrcRange = 486 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 487 scanAndAddDeps(DstN, SrcRange); 488 } 489 } 490 // 3. The new section is above the old section. 491 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 492 // +---+ - - 493 // | | SrcN | | 494 // |New| | | SrcRange | DstRange 495 // | | v | | 496 // | | DstN - | 497 // | | | 498 // +---+ - 499 // +---+ 500 // |Old| 501 // | | 502 // +---+ 503 // When scanning for deps with destination in NewInterval we need to fully 504 // scan the interval. This is the same as the scanning for a new DAG. 505 FullScan(NewInterval); 506 507 // +---+ - 508 // | | | 509 // |New| SrcN | SrcRange 510 // | | | | 511 // | | | | 512 // | | | | 513 // +---+ | - 514 // +---+ | - 515 // |Old| v | DstRange 516 // | | DstN | 517 // +---+ - 518 // When scanning for deps with destination in DAGInterval we need to 519 // consider sources from the NewInterval only, because all intra-DAGInterval 520 // dependencies have already been created. 521 auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); 522 auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 523 for (MemDGNode &DstN : DstRangeOld) 524 scanAndAddDeps(DstN, SrcRange); 525 } else { 526 llvm_unreachable("We don't expect extending in both directions!"); 527 } 528 529 DAGInterval = Union; 530 return NewInterval; 531 } 532 533 #ifndef NDEBUG 534 void DependencyGraph::print(raw_ostream &OS) const { 535 // InstrToNodeMap is unordered so we need to create an ordered vector. 536 SmallVector<DGNode *> Nodes; 537 Nodes.reserve(InstrToNodeMap.size()); 538 for (const auto &Pair : InstrToNodeMap) 539 Nodes.push_back(Pair.second.get()); 540 // Sort them based on which one comes first in the BB. 541 sort(Nodes, [](DGNode *N1, DGNode *N2) { 542 return N1->getInstruction()->comesBefore(N2->getInstruction()); 543 }); 544 for (auto *N : Nodes) 545 N->print(OS, /*PrintDeps=*/true); 546 } 547 548 void DependencyGraph::dump() const { 549 print(dbgs()); 550 dbgs() << "\n"; 551 } 552 #endif // NDEBUG 553 554 } // namespace llvm::sandboxir 555