1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" 14 15 namespace llvm::sandboxir { 16 17 PredIterator::value_type PredIterator::operator*() { 18 // If it's a DGNode then we dereference the operand iterator. 19 if (!isa<MemDGNode>(N)) { 20 assert(OpIt != OpItE && "Can't dereference end iterator!"); 21 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 22 } 23 // It's a MemDGNode, so we check if we return either the use-def operand, 24 // or a mem predecessor. 25 if (OpIt != OpItE) 26 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 27 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 28 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 29 "Cant' dereference end iterator!"); 30 return *MemIt; 31 } 32 33 PredIterator &PredIterator::operator++() { 34 // If it's a DGNode then we increment the use-def iterator. 35 if (!isa<MemDGNode>(N)) { 36 assert(OpIt != OpItE && "Already at end!"); 37 ++OpIt; 38 // Skip operands that are not instructions. 39 OpIt = skipNonInstr(OpIt, OpItE); 40 return *this; 41 } 42 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 43 // need to first increment that. 44 if (OpIt != OpItE) { 45 ++OpIt; 46 // Skip operands that are not instructions. 47 OpIt = skipNonInstr(OpIt, OpItE); 48 return *this; 49 } 50 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 51 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 52 ++MemIt; 53 return *this; 54 } 55 56 bool PredIterator::operator==(const PredIterator &Other) const { 57 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 58 assert(N == Other.N && "Iterators of different nodes!"); 59 return OpIt == Other.OpIt && MemIt == Other.MemIt; 60 } 61 62 DGNode::~DGNode() { 63 if (SB == nullptr) 64 return; 65 SB->eraseFromBundle(this); 66 } 67 68 #ifndef NDEBUG 69 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { 70 OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n"; 71 } 72 void DGNode::dump() const { print(dbgs()); } 73 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 74 DGNode::print(OS, false); 75 if (PrintDeps) { 76 // Print memory preds. 77 static constexpr const unsigned Indent = 4; 78 for (auto *Pred : MemPreds) 79 OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n"; 80 } 81 } 82 #endif // NDEBUG 83 84 MemDGNode * 85 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 86 const DependencyGraph &DAG) { 87 Instruction *I = Intvl.top(); 88 Instruction *BeforeI = Intvl.bottom(); 89 // Walk down the chain looking for a mem-dep candidate instruction. 90 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 91 I = I->getNextNode(); 92 if (!DGNode::isMemDepNodeCandidate(I)) 93 return nullptr; 94 return cast<MemDGNode>(DAG.getNode(I)); 95 } 96 97 MemDGNode * 98 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 99 const DependencyGraph &DAG) { 100 Instruction *I = Intvl.bottom(); 101 Instruction *AfterI = Intvl.top(); 102 // Walk up the chain looking for a mem-dep candidate instruction. 103 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 104 I = I->getPrevNode(); 105 if (!DGNode::isMemDepNodeCandidate(I)) 106 return nullptr; 107 return cast<MemDGNode>(DAG.getNode(I)); 108 } 109 110 Interval<MemDGNode> 111 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 112 DependencyGraph &DAG) { 113 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 114 // If we couldn't find a mem node in range TopN - BotN then it's empty. 115 if (TopMemN == nullptr) 116 return {}; 117 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 118 assert(BotMemN != nullptr && "TopMemN should be null too!"); 119 // Now that we have the mem-dep nodes, create and return the range. 120 return Interval<MemDGNode>(TopMemN, BotMemN); 121 } 122 123 DependencyGraph::DependencyType 124 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 125 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 126 if (FromI->mayWriteToMemory()) { 127 if (ToI->mayReadFromMemory()) 128 return DependencyType::ReadAfterWrite; 129 if (ToI->mayWriteToMemory()) 130 return DependencyType::WriteAfterWrite; 131 } else if (FromI->mayReadFromMemory()) { 132 if (ToI->mayWriteToMemory()) 133 return DependencyType::WriteAfterRead; 134 } 135 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 136 return DependencyType::Control; 137 if (ToI->isTerminator()) 138 return DependencyType::Control; 139 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 140 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 141 return DependencyType::Other; 142 return DependencyType::None; 143 } 144 145 static bool isOrdered(Instruction *I) { 146 auto IsOrdered = [](Instruction *I) { 147 if (auto *LI = dyn_cast<LoadInst>(I)) 148 return !LI->isUnordered(); 149 if (auto *SI = dyn_cast<StoreInst>(I)) 150 return !SI->isUnordered(); 151 if (DGNode::isFenceLike(I)) 152 return true; 153 return false; 154 }; 155 bool Is = IsOrdered(I); 156 assert((!Is || DGNode::isMemDepCandidate(I)) && 157 "An ordered instruction must be a MemDepCandidate!"); 158 return Is; 159 } 160 161 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 162 DependencyType DepType) { 163 std::optional<MemoryLocation> DstLocOpt = 164 Utils::memoryLocationGetOrNone(DstI); 165 if (!DstLocOpt) 166 return true; 167 // Check aliasing. 168 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 169 "Expected a mem instr"); 170 // TODO: Check AABudget 171 ModRefInfo SrcModRef = 172 isOrdered(SrcI) 173 ? ModRefInfo::ModRef 174 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 175 switch (DepType) { 176 case DependencyType::ReadAfterWrite: 177 case DependencyType::WriteAfterWrite: 178 return isModSet(SrcModRef); 179 case DependencyType::WriteAfterRead: 180 return isRefSet(SrcModRef); 181 default: 182 llvm_unreachable("Expected only RAW, WAW and WAR!"); 183 } 184 } 185 186 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 187 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 188 switch (RoughDepType) { 189 case DependencyType::ReadAfterWrite: 190 case DependencyType::WriteAfterWrite: 191 case DependencyType::WriteAfterRead: 192 return alias(SrcI, DstI, RoughDepType); 193 case DependencyType::Control: 194 // Adding actual dep edges from PHIs/to terminator would just create too 195 // many edges, which would be bad for compile-time. 196 // So we ignore them in the DAG formation but handle them in the 197 // scheduler, while sorting the ready list. 198 return false; 199 case DependencyType::Other: 200 return true; 201 case DependencyType::None: 202 return false; 203 } 204 llvm_unreachable("Unknown DependencyType enum"); 205 } 206 207 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 208 const Interval<MemDGNode> &SrcScanRange) { 209 assert(isa<MemDGNode>(DstN) && 210 "DstN is the mem dep destination, so it must be mem"); 211 Instruction *DstI = DstN.getInstruction(); 212 // Walk up the instruction chain from ScanRange bottom to top, looking for 213 // memory instrs that may alias. 214 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 215 Instruction *SrcI = SrcN.getInstruction(); 216 if (hasDep(SrcI, DstI)) 217 DstN.addMemPred(&SrcN); 218 } 219 } 220 221 void DependencyGraph::setDefUseUnscheduledSuccs( 222 const Interval<Instruction> &NewInterval) { 223 // +---+ 224 // | | Def 225 // | | | 226 // | | v 227 // | | Use 228 // +---+ 229 // Set the intra-interval counters in NewInterval. 230 for (Instruction &I : NewInterval) { 231 for (Value *Op : I.operands()) { 232 auto *OpI = dyn_cast<Instruction>(Op); 233 if (OpI == nullptr) 234 continue; 235 if (!NewInterval.contains(OpI)) 236 continue; 237 auto *OpN = getNode(OpI); 238 if (OpN == nullptr) 239 continue; 240 ++OpN->UnscheduledSuccs; 241 } 242 } 243 244 // Now handle the cross-interval edges. 245 bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval); 246 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 247 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 248 // +---+ 249 // |Top| 250 // | | Def 251 // +---+ | 252 // | | v 253 // |Bot| Use 254 // | | 255 // +---+ 256 // Walk over all instructions in "BotInterval" and update the counter 257 // of operands that are in "TopInterval". 258 for (Instruction &BotI : BotInterval) { 259 auto *BotN = getNode(&BotI); 260 // Skip scheduled nodes. 261 if (BotN->scheduled()) 262 continue; 263 for (Value *Op : BotI.operands()) { 264 auto *OpI = dyn_cast<Instruction>(Op); 265 if (OpI == nullptr) 266 continue; 267 if (!TopInterval.contains(OpI)) 268 continue; 269 auto *OpN = getNode(OpI); 270 if (OpN == nullptr) 271 continue; 272 ++OpN->UnscheduledSuccs; 273 } 274 } 275 } 276 277 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 278 // Create Nodes only for the new sections of the DAG. 279 DGNode *LastN = getOrCreateNode(NewInterval.top()); 280 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 281 for (Instruction &I : drop_begin(NewInterval)) { 282 auto *N = getOrCreateNode(&I); 283 // Build the Mem node chain. 284 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 285 MemN->setPrevNode(LastMemN); 286 if (LastMemN != nullptr) 287 LastMemN->setNextNode(MemN); 288 LastMemN = MemN; 289 } 290 } 291 // Link new MemDGNode chain with the old one, if any. 292 if (!DAGInterval.empty()) { 293 bool NewIsAbove = NewInterval.comesBefore(DAGInterval); 294 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 295 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 296 MemDGNode *LinkTopN = 297 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 298 MemDGNode *LinkBotN = 299 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 300 assert((LinkTopN == nullptr || LinkBotN == nullptr || 301 LinkTopN->comesBefore(LinkBotN)) && 302 "Wrong order!"); 303 if (LinkTopN != nullptr && LinkBotN != nullptr) { 304 LinkTopN->setNextNode(LinkBotN); 305 LinkBotN->setPrevNode(LinkTopN); 306 } 307 #ifndef NDEBUG 308 // TODO: Remove this once we've done enough testing. 309 // Check that the chain is well formed. 310 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 311 MemDGNode *ChainTopN = 312 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 313 MemDGNode *ChainBotN = 314 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 315 if (ChainTopN != nullptr && ChainBotN != nullptr) { 316 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 317 LastN = N, N = N->getNextNode()) { 318 assert(N == LastN->getNextNode() && "Bad chain!"); 319 assert(N->getPrevNode() == LastN && "Bad chain!"); 320 } 321 } 322 #endif // NDEBUG 323 } 324 325 setDefUseUnscheduledSuccs(NewInterval); 326 } 327 328 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 329 if (Instrs.empty()) 330 return {}; 331 332 Interval<Instruction> InstrsInterval(Instrs); 333 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 334 auto NewInterval = Union.getSingleDiff(DAGInterval); 335 if (NewInterval.empty()) 336 return {}; 337 338 createNewNodes(NewInterval); 339 340 // Create the dependencies. 341 // 342 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. 343 // +---+ - - 344 // | | SrcN | | 345 // | | | | SrcRange | 346 // |New| v | | DstRange 347 // | | DstN - | 348 // | | | 349 // +---+ - 350 // We are scanning for deps with destination in NewInterval and sources in 351 // NewInterval until DstN, for each DstN. 352 auto FullScan = [this](const Interval<Instruction> Intvl) { 353 auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this); 354 if (!DstRange.empty()) { 355 for (MemDGNode &DstN : drop_begin(DstRange)) { 356 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 357 scanAndAddDeps(DstN, SrcRange); 358 } 359 } 360 }; 361 if (DAGInterval.empty()) { 362 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 363 FullScan(NewInterval); 364 } 365 // 2. The new section is below the old section. 366 // +---+ - 367 // | | | 368 // |Old| SrcN | 369 // | | | | 370 // +---+ | | SrcRange 371 // +---+ | | - 372 // | | | | | 373 // |New| v | | DstRange 374 // | | DstN - | 375 // | | | 376 // +---+ - 377 // We are scanning for deps with destination in NewInterval because the deps 378 // in DAGInterval have already been computed. We consider sources in the whole 379 // range including both NewInterval and DAGInterval until DstN, for each DstN. 380 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 381 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 382 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 383 DAGInterval.getUnionInterval(NewInterval), *this); 384 for (MemDGNode &DstN : DstRange) { 385 auto SrcRange = 386 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 387 scanAndAddDeps(DstN, SrcRange); 388 } 389 } 390 // 3. The new section is above the old section. 391 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 392 // +---+ - - 393 // | | SrcN | | 394 // |New| | | SrcRange | DstRange 395 // | | v | | 396 // | | DstN - | 397 // | | | 398 // +---+ - 399 // +---+ 400 // |Old| 401 // | | 402 // +---+ 403 // When scanning for deps with destination in NewInterval we need to fully 404 // scan the interval. This is the same as the scanning for a new DAG. 405 FullScan(NewInterval); 406 407 // +---+ - 408 // | | | 409 // |New| SrcN | SrcRange 410 // | | | | 411 // | | | | 412 // | | | | 413 // +---+ | - 414 // +---+ | - 415 // |Old| v | DstRange 416 // | | DstN | 417 // +---+ - 418 // When scanning for deps with destination in DAGInterval we need to 419 // consider sources from the NewInterval only, because all intra-DAGInterval 420 // dependencies have already been created. 421 auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); 422 auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 423 for (MemDGNode &DstN : DstRangeOld) 424 scanAndAddDeps(DstN, SrcRange); 425 } else { 426 llvm_unreachable("We don't expect extending in both directions!"); 427 } 428 429 DAGInterval = Union; 430 return NewInterval; 431 } 432 433 #ifndef NDEBUG 434 void DependencyGraph::print(raw_ostream &OS) const { 435 // InstrToNodeMap is unordered so we need to create an ordered vector. 436 SmallVector<DGNode *> Nodes; 437 Nodes.reserve(InstrToNodeMap.size()); 438 for (const auto &Pair : InstrToNodeMap) 439 Nodes.push_back(Pair.second.get()); 440 // Sort them based on which one comes first in the BB. 441 sort(Nodes, [](DGNode *N1, DGNode *N2) { 442 return N1->getInstruction()->comesBefore(N2->getInstruction()); 443 }); 444 for (auto *N : Nodes) 445 N->print(OS, /*PrintDeps=*/true); 446 } 447 448 void DependencyGraph::dump() const { 449 print(dbgs()); 450 dbgs() << "\n"; 451 } 452 #endif // NDEBUG 453 454 } // namespace llvm::sandboxir 455