1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 14 namespace llvm::sandboxir { 15 16 PredIterator::value_type PredIterator::operator*() { 17 // If it's a DGNode then we dereference the operand iterator. 18 if (!isa<MemDGNode>(N)) { 19 assert(OpIt != OpItE && "Can't dereference end iterator!"); 20 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 21 } 22 // It's a MemDGNode, so we check if we return either the use-def operand, 23 // or a mem predecessor. 24 if (OpIt != OpItE) 25 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 26 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 27 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 28 "Cant' dereference end iterator!"); 29 return *MemIt; 30 } 31 32 PredIterator &PredIterator::operator++() { 33 // If it's a DGNode then we increment the use-def iterator. 34 if (!isa<MemDGNode>(N)) { 35 assert(OpIt != OpItE && "Already at end!"); 36 ++OpIt; 37 // Skip operands that are not instructions. 38 OpIt = skipNonInstr(OpIt, OpItE); 39 return *this; 40 } 41 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 42 // need to first increment that. 43 if (OpIt != OpItE) { 44 ++OpIt; 45 // Skip operands that are not instructions. 46 OpIt = skipNonInstr(OpIt, OpItE); 47 return *this; 48 } 49 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 50 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 51 ++MemIt; 52 return *this; 53 } 54 55 bool PredIterator::operator==(const PredIterator &Other) const { 56 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 57 assert(N == Other.N && "Iterators of different nodes!"); 58 return OpIt == Other.OpIt && MemIt == Other.MemIt; 59 } 60 61 #ifndef NDEBUG 62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { 63 OS << *I << " USuccs:" << UnscheduledSuccs << "\n"; 64 } 65 void DGNode::dump() const { print(dbgs()); } 66 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 67 DGNode::print(OS, false); 68 if (PrintDeps) { 69 // Print memory preds. 70 static constexpr const unsigned Indent = 4; 71 for (auto *Pred : MemPreds) 72 OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n"; 73 } 74 } 75 #endif // NDEBUG 76 77 MemDGNode * 78 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 79 const DependencyGraph &DAG) { 80 Instruction *I = Intvl.top(); 81 Instruction *BeforeI = Intvl.bottom(); 82 // Walk down the chain looking for a mem-dep candidate instruction. 83 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 84 I = I->getNextNode(); 85 if (!DGNode::isMemDepNodeCandidate(I)) 86 return nullptr; 87 return cast<MemDGNode>(DAG.getNode(I)); 88 } 89 90 MemDGNode * 91 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 92 const DependencyGraph &DAG) { 93 Instruction *I = Intvl.bottom(); 94 Instruction *AfterI = Intvl.top(); 95 // Walk up the chain looking for a mem-dep candidate instruction. 96 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 97 I = I->getPrevNode(); 98 if (!DGNode::isMemDepNodeCandidate(I)) 99 return nullptr; 100 return cast<MemDGNode>(DAG.getNode(I)); 101 } 102 103 Interval<MemDGNode> 104 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 105 DependencyGraph &DAG) { 106 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 107 // If we couldn't find a mem node in range TopN - BotN then it's empty. 108 if (TopMemN == nullptr) 109 return {}; 110 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 111 assert(BotMemN != nullptr && "TopMemN should be null too!"); 112 // Now that we have the mem-dep nodes, create and return the range. 113 return Interval<MemDGNode>(TopMemN, BotMemN); 114 } 115 116 DependencyGraph::DependencyType 117 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 118 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 119 if (FromI->mayWriteToMemory()) { 120 if (ToI->mayReadFromMemory()) 121 return DependencyType::ReadAfterWrite; 122 if (ToI->mayWriteToMemory()) 123 return DependencyType::WriteAfterWrite; 124 } else if (FromI->mayReadFromMemory()) { 125 if (ToI->mayWriteToMemory()) 126 return DependencyType::WriteAfterRead; 127 } 128 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 129 return DependencyType::Control; 130 if (ToI->isTerminator()) 131 return DependencyType::Control; 132 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 133 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 134 return DependencyType::Other; 135 return DependencyType::None; 136 } 137 138 static bool isOrdered(Instruction *I) { 139 auto IsOrdered = [](Instruction *I) { 140 if (auto *LI = dyn_cast<LoadInst>(I)) 141 return !LI->isUnordered(); 142 if (auto *SI = dyn_cast<StoreInst>(I)) 143 return !SI->isUnordered(); 144 if (DGNode::isFenceLike(I)) 145 return true; 146 return false; 147 }; 148 bool Is = IsOrdered(I); 149 assert((!Is || DGNode::isMemDepCandidate(I)) && 150 "An ordered instruction must be a MemDepCandidate!"); 151 return Is; 152 } 153 154 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 155 DependencyType DepType) { 156 std::optional<MemoryLocation> DstLocOpt = 157 Utils::memoryLocationGetOrNone(DstI); 158 if (!DstLocOpt) 159 return true; 160 // Check aliasing. 161 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 162 "Expected a mem instr"); 163 // TODO: Check AABudget 164 ModRefInfo SrcModRef = 165 isOrdered(SrcI) 166 ? ModRefInfo::ModRef 167 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 168 switch (DepType) { 169 case DependencyType::ReadAfterWrite: 170 case DependencyType::WriteAfterWrite: 171 return isModSet(SrcModRef); 172 case DependencyType::WriteAfterRead: 173 return isRefSet(SrcModRef); 174 default: 175 llvm_unreachable("Expected only RAW, WAW and WAR!"); 176 } 177 } 178 179 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 180 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 181 switch (RoughDepType) { 182 case DependencyType::ReadAfterWrite: 183 case DependencyType::WriteAfterWrite: 184 case DependencyType::WriteAfterRead: 185 return alias(SrcI, DstI, RoughDepType); 186 case DependencyType::Control: 187 // Adding actual dep edges from PHIs/to terminator would just create too 188 // many edges, which would be bad for compile-time. 189 // So we ignore them in the DAG formation but handle them in the 190 // scheduler, while sorting the ready list. 191 return false; 192 case DependencyType::Other: 193 return true; 194 case DependencyType::None: 195 return false; 196 } 197 llvm_unreachable("Unknown DependencyType enum"); 198 } 199 200 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 201 const Interval<MemDGNode> &SrcScanRange) { 202 assert(isa<MemDGNode>(DstN) && 203 "DstN is the mem dep destination, so it must be mem"); 204 Instruction *DstI = DstN.getInstruction(); 205 // Walk up the instruction chain from ScanRange bottom to top, looking for 206 // memory instrs that may alias. 207 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 208 Instruction *SrcI = SrcN.getInstruction(); 209 if (hasDep(SrcI, DstI)) 210 DstN.addMemPred(&SrcN); 211 } 212 } 213 214 void DependencyGraph::setDefUseUnscheduledSuccs( 215 const Interval<Instruction> &NewInterval) { 216 // +---+ 217 // | | Def 218 // | | | 219 // | | v 220 // | | Use 221 // +---+ 222 // Set the intra-interval counters in NewInterval. 223 for (Instruction &I : NewInterval) { 224 for (Value *Op : I.operands()) { 225 auto *OpI = dyn_cast<Instruction>(Op); 226 if (OpI == nullptr) 227 continue; 228 if (!NewInterval.contains(OpI)) 229 continue; 230 auto *OpN = getNode(OpI); 231 if (OpN == nullptr) 232 continue; 233 ++OpN->UnscheduledSuccs; 234 } 235 } 236 237 // Now handle the cross-interval edges. 238 bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval); 239 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 240 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 241 // +---+ 242 // |Top| 243 // | | Def 244 // +---+ | 245 // | | v 246 // |Bot| Use 247 // | | 248 // +---+ 249 // Walk over all instructions in "BotInterval" and update the counter 250 // of operands that are in "TopInterval". 251 for (Instruction &BotI : BotInterval) { 252 for (Value *Op : BotI.operands()) { 253 auto *OpI = dyn_cast<Instruction>(Op); 254 if (OpI == nullptr) 255 continue; 256 if (!TopInterval.contains(OpI)) 257 continue; 258 auto *OpN = getNode(OpI); 259 if (OpN == nullptr) 260 continue; 261 ++OpN->UnscheduledSuccs; 262 } 263 } 264 } 265 266 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 267 // Create Nodes only for the new sections of the DAG. 268 DGNode *LastN = getOrCreateNode(NewInterval.top()); 269 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 270 for (Instruction &I : drop_begin(NewInterval)) { 271 auto *N = getOrCreateNode(&I); 272 // Build the Mem node chain. 273 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 274 MemN->setPrevNode(LastMemN); 275 if (LastMemN != nullptr) 276 LastMemN->setNextNode(MemN); 277 LastMemN = MemN; 278 } 279 } 280 // Link new MemDGNode chain with the old one, if any. 281 if (!DAGInterval.empty()) { 282 bool NewIsAbove = NewInterval.comesBefore(DAGInterval); 283 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 284 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 285 MemDGNode *LinkTopN = 286 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 287 MemDGNode *LinkBotN = 288 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 289 assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!"); 290 if (LinkTopN != nullptr && LinkBotN != nullptr) { 291 LinkTopN->setNextNode(LinkBotN); 292 LinkBotN->setPrevNode(LinkTopN); 293 } 294 #ifndef NDEBUG 295 // TODO: Remove this once we've done enough testing. 296 // Check that the chain is well formed. 297 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 298 MemDGNode *ChainTopN = 299 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 300 MemDGNode *ChainBotN = 301 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 302 if (ChainTopN != nullptr && ChainBotN != nullptr) { 303 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 304 LastN = N, N = N->getNextNode()) { 305 assert(N == LastN->getNextNode() && "Bad chain!"); 306 assert(N->getPrevNode() == LastN && "Bad chain!"); 307 } 308 } 309 #endif // NDEBUG 310 } 311 312 setDefUseUnscheduledSuccs(NewInterval); 313 } 314 315 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 316 if (Instrs.empty()) 317 return {}; 318 319 Interval<Instruction> InstrsInterval(Instrs); 320 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 321 auto NewInterval = Union.getSingleDiff(DAGInterval); 322 if (NewInterval.empty()) 323 return {}; 324 325 createNewNodes(NewInterval); 326 327 // Create the dependencies. 328 // 329 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. 330 // +---+ - - 331 // | | SrcN | | 332 // | | | | SrcRange | 333 // |New| v | | DstRange 334 // | | DstN - | 335 // | | | 336 // +---+ - 337 // We are scanning for deps with destination in NewInterval and sources in 338 // NewInterval until DstN, for each DstN. 339 auto FullScan = [this](const Interval<Instruction> Intvl) { 340 auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this); 341 if (!DstRange.empty()) { 342 for (MemDGNode &DstN : drop_begin(DstRange)) { 343 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 344 scanAndAddDeps(DstN, SrcRange); 345 } 346 } 347 }; 348 if (DAGInterval.empty()) { 349 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 350 FullScan(NewInterval); 351 } 352 // 2. The new section is below the old section. 353 // +---+ - 354 // | | | 355 // |Old| SrcN | 356 // | | | | 357 // +---+ | | SrcRange 358 // +---+ | | - 359 // | | | | | 360 // |New| v | | DstRange 361 // | | DstN - | 362 // | | | 363 // +---+ - 364 // We are scanning for deps with destination in NewInterval because the deps 365 // in DAGInterval have already been computed. We consider sources in the whole 366 // range including both NewInterval and DAGInterval until DstN, for each DstN. 367 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 368 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 369 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 370 DAGInterval.getUnionInterval(NewInterval), *this); 371 for (MemDGNode &DstN : DstRange) { 372 auto SrcRange = 373 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 374 scanAndAddDeps(DstN, SrcRange); 375 } 376 } 377 // 3. The new section is above the old section. 378 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 379 // +---+ - - 380 // | | SrcN | | 381 // |New| | | SrcRange | DstRange 382 // | | v | | 383 // | | DstN - | 384 // | | | 385 // +---+ - 386 // +---+ 387 // |Old| 388 // | | 389 // +---+ 390 // When scanning for deps with destination in NewInterval we need to fully 391 // scan the interval. This is the same as the scanning for a new DAG. 392 FullScan(NewInterval); 393 394 // +---+ - 395 // | | | 396 // |New| SrcN | SrcRange 397 // | | | | 398 // | | | | 399 // | | | | 400 // +---+ | - 401 // +---+ | - 402 // |Old| v | DstRange 403 // | | DstN | 404 // +---+ - 405 // When scanning for deps with destination in DAGInterval we need to 406 // consider sources from the NewInterval only, because all intra-DAGInterval 407 // dependencies have already been created. 408 auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); 409 auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 410 for (MemDGNode &DstN : DstRangeOld) 411 scanAndAddDeps(DstN, SrcRange); 412 } else { 413 llvm_unreachable("We don't expect extending in both directions!"); 414 } 415 416 DAGInterval = Union; 417 return NewInterval; 418 } 419 420 #ifndef NDEBUG 421 void DependencyGraph::print(raw_ostream &OS) const { 422 // InstrToNodeMap is unordered so we need to create an ordered vector. 423 SmallVector<DGNode *> Nodes; 424 Nodes.reserve(InstrToNodeMap.size()); 425 for (const auto &Pair : InstrToNodeMap) 426 Nodes.push_back(Pair.second.get()); 427 // Sort them based on which one comes first in the BB. 428 sort(Nodes, [](DGNode *N1, DGNode *N2) { 429 return N1->getInstruction()->comesBefore(N2->getInstruction()); 430 }); 431 for (auto *N : Nodes) 432 N->print(OS, /*PrintDeps=*/true); 433 } 434 435 void DependencyGraph::dump() const { 436 print(dbgs()); 437 dbgs() << "\n"; 438 } 439 #endif // NDEBUG 440 441 } // namespace llvm::sandboxir 442