1 //===- DependencyGraph.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/SandboxIR/Instruction.h" 12 #include "llvm/SandboxIR/Utils.h" 13 14 namespace llvm::sandboxir { 15 16 PredIterator::value_type PredIterator::operator*() { 17 // If it's a DGNode then we dereference the operand iterator. 18 if (!isa<MemDGNode>(N)) { 19 assert(OpIt != OpItE && "Can't dereference end iterator!"); 20 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 21 } 22 // It's a MemDGNode, so we check if we return either the use-def operand, 23 // or a mem predecessor. 24 if (OpIt != OpItE) 25 return DAG->getNode(cast<Instruction>((Value *)*OpIt)); 26 // It's a MemDGNode with OpIt == end, so we need to use MemIt. 27 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && 28 "Cant' dereference end iterator!"); 29 return *MemIt; 30 } 31 32 PredIterator &PredIterator::operator++() { 33 // If it's a DGNode then we increment the use-def iterator. 34 if (!isa<MemDGNode>(N)) { 35 assert(OpIt != OpItE && "Already at end!"); 36 ++OpIt; 37 // Skip operands that are not instructions. 38 OpIt = skipNonInstr(OpIt, OpItE); 39 return *this; 40 } 41 // It's a MemDGNode, so if we are not at the end of the use-def iterator we 42 // need to first increment that. 43 if (OpIt != OpItE) { 44 ++OpIt; 45 // Skip operands that are not instructions. 46 OpIt = skipNonInstr(OpIt, OpItE); 47 return *this; 48 } 49 // It's a MemDGNode with OpIt == end, so we need to increment MemIt. 50 assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!"); 51 ++MemIt; 52 return *this; 53 } 54 55 bool PredIterator::operator==(const PredIterator &Other) const { 56 assert(DAG == Other.DAG && "Iterators of different DAGs!"); 57 assert(N == Other.N && "Iterators of different nodes!"); 58 return OpIt == Other.OpIt && MemIt == Other.MemIt; 59 } 60 61 #ifndef NDEBUG 62 void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); } 63 void DGNode::dump() const { 64 print(dbgs()); 65 dbgs() << "\n"; 66 } 67 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const { 68 I->dumpOS(OS); 69 if (PrintDeps) { 70 // Print memory preds. 71 static constexpr const unsigned Indent = 4; 72 for (auto *Pred : MemPreds) { 73 OS.indent(Indent) << "<-"; 74 Pred->print(OS, false); 75 OS << "\n"; 76 } 77 } 78 } 79 #endif // NDEBUG 80 81 MemDGNode * 82 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl, 83 const DependencyGraph &DAG) { 84 Instruction *I = Intvl.top(); 85 Instruction *BeforeI = Intvl.bottom(); 86 // Walk down the chain looking for a mem-dep candidate instruction. 87 while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) 88 I = I->getNextNode(); 89 if (!DGNode::isMemDepNodeCandidate(I)) 90 return nullptr; 91 return cast<MemDGNode>(DAG.getNode(I)); 92 } 93 94 MemDGNode * 95 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl, 96 const DependencyGraph &DAG) { 97 Instruction *I = Intvl.bottom(); 98 Instruction *AfterI = Intvl.top(); 99 // Walk up the chain looking for a mem-dep candidate instruction. 100 while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) 101 I = I->getPrevNode(); 102 if (!DGNode::isMemDepNodeCandidate(I)) 103 return nullptr; 104 return cast<MemDGNode>(DAG.getNode(I)); 105 } 106 107 Interval<MemDGNode> 108 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs, 109 DependencyGraph &DAG) { 110 auto *TopMemN = getTopMemDGNode(Instrs, DAG); 111 // If we couldn't find a mem node in range TopN - BotN then it's empty. 112 if (TopMemN == nullptr) 113 return {}; 114 auto *BotMemN = getBotMemDGNode(Instrs, DAG); 115 assert(BotMemN != nullptr && "TopMemN should be null too!"); 116 // Now that we have the mem-dep nodes, create and return the range. 117 return Interval<MemDGNode>(TopMemN, BotMemN); 118 } 119 120 DependencyGraph::DependencyType 121 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { 122 // TODO: Perhaps compile-time improvement by skipping if neither is mem? 123 if (FromI->mayWriteToMemory()) { 124 if (ToI->mayReadFromMemory()) 125 return DependencyType::ReadAfterWrite; 126 if (ToI->mayWriteToMemory()) 127 return DependencyType::WriteAfterWrite; 128 } else if (FromI->mayReadFromMemory()) { 129 if (ToI->mayWriteToMemory()) 130 return DependencyType::WriteAfterRead; 131 } 132 if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI)) 133 return DependencyType::Control; 134 if (ToI->isTerminator()) 135 return DependencyType::Control; 136 if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || 137 DGNode::isStackSaveOrRestoreIntrinsic(ToI)) 138 return DependencyType::Other; 139 return DependencyType::None; 140 } 141 142 static bool isOrdered(Instruction *I) { 143 auto IsOrdered = [](Instruction *I) { 144 if (auto *LI = dyn_cast<LoadInst>(I)) 145 return !LI->isUnordered(); 146 if (auto *SI = dyn_cast<StoreInst>(I)) 147 return !SI->isUnordered(); 148 if (DGNode::isFenceLike(I)) 149 return true; 150 return false; 151 }; 152 bool Is = IsOrdered(I); 153 assert((!Is || DGNode::isMemDepCandidate(I)) && 154 "An ordered instruction must be a MemDepCandidate!"); 155 return Is; 156 } 157 158 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, 159 DependencyType DepType) { 160 std::optional<MemoryLocation> DstLocOpt = 161 Utils::memoryLocationGetOrNone(DstI); 162 if (!DstLocOpt) 163 return true; 164 // Check aliasing. 165 assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && 166 "Expected a mem instr"); 167 // TODO: Check AABudget 168 ModRefInfo SrcModRef = 169 isOrdered(SrcI) 170 ? ModRefInfo::ModRef 171 : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); 172 switch (DepType) { 173 case DependencyType::ReadAfterWrite: 174 case DependencyType::WriteAfterWrite: 175 return isModSet(SrcModRef); 176 case DependencyType::WriteAfterRead: 177 return isRefSet(SrcModRef); 178 default: 179 llvm_unreachable("Expected only RAW, WAW and WAR!"); 180 } 181 } 182 183 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { 184 DependencyType RoughDepType = getRoughDepType(SrcI, DstI); 185 switch (RoughDepType) { 186 case DependencyType::ReadAfterWrite: 187 case DependencyType::WriteAfterWrite: 188 case DependencyType::WriteAfterRead: 189 return alias(SrcI, DstI, RoughDepType); 190 case DependencyType::Control: 191 // Adding actual dep edges from PHIs/to terminator would just create too 192 // many edges, which would be bad for compile-time. 193 // So we ignore them in the DAG formation but handle them in the 194 // scheduler, while sorting the ready list. 195 return false; 196 case DependencyType::Other: 197 return true; 198 case DependencyType::None: 199 return false; 200 } 201 llvm_unreachable("Unknown DependencyType enum"); 202 } 203 204 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, 205 const Interval<MemDGNode> &SrcScanRange) { 206 assert(isa<MemDGNode>(DstN) && 207 "DstN is the mem dep destination, so it must be mem"); 208 Instruction *DstI = DstN.getInstruction(); 209 // Walk up the instruction chain from ScanRange bottom to top, looking for 210 // memory instrs that may alias. 211 for (MemDGNode &SrcN : reverse(SrcScanRange)) { 212 Instruction *SrcI = SrcN.getInstruction(); 213 if (hasDep(SrcI, DstI)) 214 DstN.addMemPred(&SrcN); 215 } 216 } 217 218 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) { 219 // Create Nodes only for the new sections of the DAG. 220 DGNode *LastN = getOrCreateNode(NewInterval.top()); 221 MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN); 222 for (Instruction &I : drop_begin(NewInterval)) { 223 auto *N = getOrCreateNode(&I); 224 // Build the Mem node chain. 225 if (auto *MemN = dyn_cast<MemDGNode>(N)) { 226 MemN->setPrevNode(LastMemN); 227 if (LastMemN != nullptr) 228 LastMemN->setNextNode(MemN); 229 LastMemN = MemN; 230 } 231 } 232 // Link new MemDGNode chain with the old one, if any. 233 if (!DAGInterval.empty()) { 234 bool NewIsAbove = NewInterval.comesBefore(DAGInterval); 235 const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; 236 const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; 237 MemDGNode *LinkTopN = 238 MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); 239 MemDGNode *LinkBotN = 240 MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); 241 assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!"); 242 if (LinkTopN != nullptr && LinkBotN != nullptr) { 243 LinkTopN->setNextNode(LinkBotN); 244 LinkBotN->setPrevNode(LinkTopN); 245 } 246 #ifndef NDEBUG 247 // TODO: Remove this once we've done enough testing. 248 // Check that the chain is well formed. 249 auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); 250 MemDGNode *ChainTopN = 251 MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); 252 MemDGNode *ChainBotN = 253 MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); 254 if (ChainTopN != nullptr && ChainBotN != nullptr) { 255 for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; 256 LastN = N, N = N->getNextNode()) { 257 assert(N == LastN->getNextNode() && "Bad chain!"); 258 assert(N->getPrevNode() == LastN && "Bad chain!"); 259 } 260 } 261 #endif // NDEBUG 262 } 263 } 264 265 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) { 266 if (Instrs.empty()) 267 return {}; 268 269 Interval<Instruction> InstrsInterval(Instrs); 270 Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval); 271 auto NewInterval = Union.getSingleDiff(DAGInterval); 272 if (NewInterval.empty()) 273 return {}; 274 275 createNewNodes(NewInterval); 276 277 // Create the dependencies. 278 // 279 // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval. 280 // +---+ - - 281 // | | SrcN | | 282 // | | | | SrcRange | 283 // |New| v | | DstRange 284 // | | DstN - | 285 // | | | 286 // +---+ - 287 // We are scanning for deps with destination in NewInterval and sources in 288 // NewInterval until DstN, for each DstN. 289 auto FullScan = [this](const Interval<Instruction> Intvl) { 290 auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this); 291 if (!DstRange.empty()) { 292 for (MemDGNode &DstN : drop_begin(DstRange)) { 293 auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode()); 294 scanAndAddDeps(DstN, SrcRange); 295 } 296 } 297 }; 298 if (DAGInterval.empty()) { 299 assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); 300 FullScan(NewInterval); 301 } 302 // 2. The new section is below the old section. 303 // +---+ - 304 // | | | 305 // |Old| SrcN | 306 // | | | | 307 // +---+ | | SrcRange 308 // +---+ | | - 309 // | | | | | 310 // |New| v | | DstRange 311 // | | DstN - | 312 // | | | 313 // +---+ - 314 // We are scanning for deps with destination in NewInterval because the deps 315 // in DAGInterval have already been computed. We consider sources in the whole 316 // range including both NewInterval and DAGInterval until DstN, for each DstN. 317 else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { 318 auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 319 auto SrcRangeFull = MemDGNodeIntervalBuilder::make( 320 DAGInterval.getUnionInterval(NewInterval), *this); 321 for (MemDGNode &DstN : DstRange) { 322 auto SrcRange = 323 Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode()); 324 scanAndAddDeps(DstN, SrcRange); 325 } 326 } 327 // 3. The new section is above the old section. 328 else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { 329 // +---+ - - 330 // | | SrcN | | 331 // |New| | | SrcRange | DstRange 332 // | | v | | 333 // | | DstN - | 334 // | | | 335 // +---+ - 336 // +---+ 337 // |Old| 338 // | | 339 // +---+ 340 // When scanning for deps with destination in NewInterval we need to fully 341 // scan the interval. This is the same as the scanning for a new DAG. 342 FullScan(NewInterval); 343 344 // +---+ - 345 // | | | 346 // |New| SrcN | SrcRange 347 // | | | | 348 // | | | | 349 // | | | | 350 // +---+ | - 351 // +---+ | - 352 // |Old| v | DstRange 353 // | | DstN | 354 // +---+ - 355 // When scanning for deps with destination in DAGInterval we need to 356 // consider sources from the NewInterval only, because all intra-DAGInterval 357 // dependencies have already been created. 358 auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this); 359 auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); 360 for (MemDGNode &DstN : DstRangeOld) 361 scanAndAddDeps(DstN, SrcRange); 362 } else { 363 llvm_unreachable("We don't expect extending in both directions!"); 364 } 365 366 DAGInterval = Union; 367 return NewInterval; 368 } 369 370 #ifndef NDEBUG 371 void DependencyGraph::print(raw_ostream &OS) const { 372 // InstrToNodeMap is unordered so we need to create an ordered vector. 373 SmallVector<DGNode *> Nodes; 374 Nodes.reserve(InstrToNodeMap.size()); 375 for (const auto &Pair : InstrToNodeMap) 376 Nodes.push_back(Pair.second.get()); 377 // Sort them based on which one comes first in the BB. 378 sort(Nodes, [](DGNode *N1, DGNode *N2) { 379 return N1->getInstruction()->comesBefore(N2->getInstruction()); 380 }); 381 for (auto *N : Nodes) 382 N->print(OS, /*PrintDeps=*/true); 383 } 384 385 void DependencyGraph::dump() const { 386 print(dbgs()); 387 dbgs() << "\n"; 388 } 389 #endif // NDEBUG 390 391 } // namespace llvm::sandboxir 392