1 //===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // The SuspendCrossingInfo maintains data that allows to answer a question 9 // whether given two BasicBlocks A and B there is a path from A to B that 10 // passes through a suspend point. Note, SuspendCrossingInfo is invalidated 11 // by changes to the CFG including adding/removing BBs due to its use of BB 12 // ptrs in the BlockToIndexMapping. 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h" 16 #include "llvm/IR/ModuleSlotTracker.h" 17 18 // The "coro-suspend-crossing" flag is very noisy. There is another debug type, 19 // "coro-frame", which results in leaner debug spew. 
#define DEBUG_TYPE "coro-suspend-crossing"

namespace llvm {
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
// Print a human-readable label for BB to dbgs(): the block's name if it has
// one, otherwise its local slot number as assigned by MST (the "%N" number the
// IR printer would use).
static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) {
  if (BB->hasName()) {
    dbgs() << BB->getName();
    return;
  }

  dbgs() << MST.getLocalSlot(BB);
}

// Print one bitset as "<Label>: <bb> <bb> ...\n", listing, in reverse
// post-order, every block of RPOT whose bit is set in BV. Blocks that are not
// visited by RPOT (i.e. unreachable ones) are never printed even if their bit
// is set.
LLVM_DUMP_METHOD void
SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV,
                          const ReversePostOrderTraversal<Function *> &RPOT,
                          ModuleSlotTracker &MST) const {
  dbgs() << Label << ":";
  for (const BasicBlock *BB : RPOT) {
    auto BBNo = Mapping.blockToIndex(BB);
    if (BV[BBNo]) {
      dbgs() << " ";
      dumpBasicBlockLabel(BB, MST);
    }
  }
  dbgs() << "\n";
}

// Print the Consumes and Kills sets of every reachable block, in reverse
// post-order. Does nothing if no block data has been computed.
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
  if (Block.empty())
    return;

  // Any mapped block will do; it is only used to recover the owning function.
  BasicBlock *const B = Mapping.indexToBlock(0);
  Function *F = B->getParent();

  // Slot tracker used to number unnamed blocks in the output.
  ModuleSlotTracker MST(F->getParent());
  MST.incorporateFunction(*F);

  ReversePostOrderTraversal<Function *> RPOT(F);
  for (const BasicBlock *BB : RPOT) {
    auto BBNo = Mapping.blockToIndex(BB);
    dumpBasicBlockLabel(BB, MST);
    dbgs() << ":\n";
    dump(" Consumes", Block[BBNo].Consumes, RPOT, MST);
    dump(" Kills", Block[BBNo].Kills, RPOT, MST);
  }
  dbgs() << "\n";
}
#endif

// Return true if some path from From to To passes through a suspend point.
// This is a single lookup of From's bit in To's Kills set, as computed by the
// constructor's fixed-point iteration. Note that for From == To on a
// non-suspend block the bit is always clear (computeBlockData resets it);
// use hasPathOrLoopCrossingSuspendPoint to also catch that loop case.
bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
                                                      BasicBlock *To) const {
  size_t const FromIndex = Mapping.blockToIndex(From);
  size_t const ToIndex = Mapping.blockToIndex(To);
  bool const Result = Block[ToIndex].Kills[FromIndex];
  LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
                                << " crosses suspend point\n");
  return Result;
}

// Like hasPathCrossingSuspendPoint, but additionally returns true when
// From == To and the block was recorded (KillLoop) as reaching itself across
// a suspend point, i.e. it sits on a cycle that contains a suspend.
bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
    BasicBlock *From, BasicBlock *To) const {
  size_t const FromIndex = Mapping.blockToIndex(From);
  size_t const ToIndex = Mapping.blockToIndex(To);
  bool Result = Block[ToIndex].Kills[FromIndex] ||
                (From == To && Block[ToIndex].KillLoop);
  LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
                                << " crosses suspend point (path or loop)\n");
  return Result;
}

// Perform one forward dataflow sweep over the blocks in RPOT.
//
// For each block B the sweep merges its predecessors' facts into B:
//   - Consumes[I] set at B: block I reaches B.
//   - Kills[I] set at B: some path from block I to B crosses a suspend point.
// Suspend blocks turn everything they consume into kills; coro.end blocks
// clear their kills; other blocks clear their own bit, recording it in
// KillLoop first.
//
// Initialize == true: the first sweep. Every block is processed, the per-block
// Changed flags are left untouched, and the return value is always false.
// Initialize == false: blocks whose predecessors are all unchanged are skipped,
// each processed block's Changed flag is updated, and the function returns true
// if any block's Consumes or Kills set changed (i.e. another sweep is needed).
template <bool Initialize>
bool SuspendCrossingInfo::computeBlockData(
    const ReversePostOrderTraversal<Function *> &RPOT) {
  bool Changed = false;

  for (const BasicBlock *BB : RPOT) {
    auto BBNo = Mapping.blockToIndex(BB);
    auto &B = Block[BBNo];

    // The initializing sweep processes every block unconditionally, so there
    // is no need to inspect the predecessors' Changed flags.
    // NOTE(review): predecessors() here takes the BlockData, not a BasicBlock;
    // presumably a SuspendCrossingInfo helper declared in the header that maps
    // B back to its block -- confirm there.
    if constexpr (!Initialize)
      // If none of the predecessors of the current block changed, the
      // BlockData for the current block cannot change either. Predecessors
      // earlier in RPO report this sweep's status; back-edge predecessors
      // (later in RPO) still carry the flag from the previous sweep. A block
      // with no predecessors is always skipped here (all_of of an empty range).
      if (all_of(predecessors(B), [this](BasicBlock *BB) {
            return !Block[Mapping.blockToIndex(BB)].Changed;
          })) {
        B.Changed = false;
        continue;
      }

    // Save the Consumes and Kills bitsets so that it is easy to see
    // whether anything changed after propagation.
    auto SavedConsumes = B.Consumes;
    auto SavedKills = B.Kills;

    for (BasicBlock *PI : predecessors(B)) {
      auto PrevNo = Mapping.blockToIndex(PI);
      auto &P = Block[PrevNo];

      // Propagate Kills and Consumes from predecessors into B.
      B.Consumes |= P.Consumes;
      B.Kills |= P.Kills;

      // If block P is a suspend block, it should propagate kills into block
      // B for every block P consumes.
      if (P.Suspend)
        B.Kills |= P.Consumes;
    }

    if (B.Suspend) {
      // If block B is a suspend block, it should kill all of the blocks it
      // consumes.
      B.Kills |= B.Consumes;
    } else if (B.End) {
      // If block B is an end block, it should not propagate kills as the
      // blocks following coro.end() are reached during initial invocation
      // of the coroutine while all the data are still available on the
      // stack or in the registers.
      B.Kills.reset();
    } else {
      // B is neither a suspend block nor a coro.end block, so it must not
      // appear in its own kill set. If its own bit did arrive (B reaches
      // itself across a suspend point, i.e. it is on such a loop), remember
      // that in KillLoop before clearing the bit.
      B.KillLoop |= B.Kills[BBNo];
      B.Kills.reset(BBNo);
    }

    if constexpr (!Initialize) {
      B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
      Changed |= B.Changed;
    }
  }

  return Changed;
}

// Build the suspend-crossing tables for F.
//
// CoroSuspends and CoroEnds must already be normalized so that each
// intrinsic is the first insertion point of its own (at most two-instruction)
// block; this is asserted below. The resulting data is tied to F's current CFG
// via block pointers in Mapping and must be recomputed after CFG changes.
SuspendCrossingInfo::SuspendCrossingInfo(
    Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
    const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
    : Mapping(F) {
  const size_t N = Mapping.size();
  Block.resize(N);

  // Initialize every block so that it consumes itself. Changed starts out
  // true so that the first non-initializing sweep recomputes every block.
  for (size_t I = 0; I < N; ++I) {
    auto &B = Block[I];
    B.Consumes.resize(N);
    B.Kills.resize(N);
    B.Consumes.set(I);
    B.Changed = true;
  }

  // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
  // the code beyond coro.end is reachable during initial invocation of the
  // coroutine.
  for (auto *CE : CoroEnds) {
    // Verify CoroEnd was normalized
    assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() &&
           CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB");

    getBlockData(CE->getParent()).End = true;
  }

  // Mark all suspend blocks and indicate that they kill everything they
  // consume. Note, that crossing coro.save also requires a spill, as any code
  // between coro.save and coro.suspend may resume the coroutine and all of the
  // state needs to be saved by that time.
  auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
    BasicBlock *SuspendBlock = BarrierInst->getParent();
    auto &B = getBlockData(SuspendBlock);
    B.Suspend = true;
    B.Kills |= B.Consumes;
  };
  for (auto *CSI : CoroSuspends) {
    // Verify CoroSuspend was normalized
    assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() &&
           CSI->getParent()->size() <= 2 &&
           "CoroSuspend must be in its own BB");

    markSuspendBlock(CSI);
    // The block holding the matching coro.save (if any) is also treated as a
    // suspend block, per the note above.
    if (auto *Save = CSI->getCoroSave())
      markSuspendBlock(Save);
  }

  // It is considered to be faster to use RPO traversal for forward-edges
  // dataflow analysis. One initializing sweep is followed by incremental
  // sweeps until no block's Consumes/Kills sets change (fixed point).
  ReversePostOrderTraversal<Function *> RPOT(&F);
  computeBlockData</*Initialize=*/true>(RPOT);
  while (computeBlockData</*Initialize=*/false>(RPOT))
    ;

  LLVM_DEBUG(dump());
}

} // namespace llvm