xref: /llvm-project/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp (revision 4db57ab958f5bac1d85927a955f989625badf962)
1 //===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// The SuspendCrossingInfo maintains data that answers the question: given two
// BasicBlocks A and B, is there a path from A to B that passes through a
// suspend point? Note that SuspendCrossingInfo is invalidated by changes to
// the CFG, including adding or removing BBs, because it stores BB pointers in
// the BlockToIndexMapping.
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h"
16 #include "llvm/IR/ModuleSlotTracker.h"
17 
18 // The "coro-suspend-crossing" flag is very noisy. There is another debug type,
19 // "coro-frame", which results in leaner debug spew.
20 #define DEBUG_TYPE "coro-suspend-crossing"
21 
22 namespace llvm {
23 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
24 static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) {
25   if (BB->hasName()) {
26     dbgs() << BB->getName();
27     return;
28   }
29 
30   dbgs() << MST.getLocalSlot(BB);
31 }
32 
33 LLVM_DUMP_METHOD void
34 SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV,
35                           const ReversePostOrderTraversal<Function *> &RPOT,
36                           ModuleSlotTracker &MST) const {
37   dbgs() << Label << ":";
38   for (const BasicBlock *BB : RPOT) {
39     auto BBNo = Mapping.blockToIndex(BB);
40     if (BV[BBNo]) {
41       dbgs() << " ";
42       dumpBasicBlockLabel(BB, MST);
43     }
44   }
45   dbgs() << "\n";
46 }
47 
48 LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
49   if (Block.empty())
50     return;
51 
52   BasicBlock *const B = Mapping.indexToBlock(0);
53   Function *F = B->getParent();
54 
55   ModuleSlotTracker MST(F->getParent());
56   MST.incorporateFunction(*F);
57 
58   ReversePostOrderTraversal<Function *> RPOT(F);
59   for (const BasicBlock *BB : RPOT) {
60     auto BBNo = Mapping.blockToIndex(BB);
61     dumpBasicBlockLabel(BB, MST);
62     dbgs() << ":\n";
63     dump("   Consumes", Block[BBNo].Consumes, RPOT, MST);
64     dump("      Kills", Block[BBNo].Kills, RPOT, MST);
65   }
66   dbgs() << "\n";
67 }
68 #endif
69 
70 bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
71                                                       BasicBlock *To) const {
72   size_t const FromIndex = Mapping.blockToIndex(From);
73   size_t const ToIndex = Mapping.blockToIndex(To);
74   bool const Result = Block[ToIndex].Kills[FromIndex];
75   LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
76                                 << " crosses suspend point\n");
77   return Result;
78 }
79 
80 bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
81     BasicBlock *From, BasicBlock *To) const {
82   size_t const FromIndex = Mapping.blockToIndex(From);
83   size_t const ToIndex = Mapping.blockToIndex(To);
84   bool Result = Block[ToIndex].Kills[FromIndex] ||
85                 (From == To && Block[ToIndex].KillLoop);
86   LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
87                                 << " crosses suspend point (path or loop)\n");
88   return Result;
89 }
90 
91 template <bool Initialize>
92 bool SuspendCrossingInfo::computeBlockData(
93     const ReversePostOrderTraversal<Function *> &RPOT) {
94   bool Changed = false;
95 
96   for (const BasicBlock *BB : RPOT) {
97     auto BBNo = Mapping.blockToIndex(BB);
98     auto &B = Block[BBNo];
99 
100     // We don't need to count the predecessors when initialization.
101     if constexpr (!Initialize)
102       // If all the predecessors of the current Block don't change,
103       // the BlockData for the current block must not change too.
104       if (all_of(predecessors(B), [this](BasicBlock *BB) {
105             return !Block[Mapping.blockToIndex(BB)].Changed;
106           })) {
107         B.Changed = false;
108         continue;
109       }
110 
111     // Saved Consumes and Kills bitsets so that it is easy to see
112     // if anything changed after propagation.
113     auto SavedConsumes = B.Consumes;
114     auto SavedKills = B.Kills;
115 
116     for (BasicBlock *PI : predecessors(B)) {
117       auto PrevNo = Mapping.blockToIndex(PI);
118       auto &P = Block[PrevNo];
119 
120       // Propagate Kills and Consumes from predecessors into B.
121       B.Consumes |= P.Consumes;
122       B.Kills |= P.Kills;
123 
124       // If block P is a suspend block, it should propagate kills into block
125       // B for every block P consumes.
126       if (P.Suspend)
127         B.Kills |= P.Consumes;
128     }
129 
130     if (B.Suspend) {
131       // If block B is a suspend block, it should kill all of the blocks it
132       // consumes.
133       B.Kills |= B.Consumes;
134     } else if (B.End) {
135       // If block B is an end block, it should not propagate kills as the
136       // blocks following coro.end() are reached during initial invocation
137       // of the coroutine while all the data are still available on the
138       // stack or in the registers.
139       B.Kills.reset();
140     } else {
141       // This is reached when B block it not Suspend nor coro.end and it
142       // need to make sure that it is not in the kill set.
143       B.KillLoop |= B.Kills[BBNo];
144       B.Kills.reset(BBNo);
145     }
146 
147     if constexpr (!Initialize) {
148       B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
149       Changed |= B.Changed;
150     }
151   }
152 
153   return Changed;
154 }
155 
156 SuspendCrossingInfo::SuspendCrossingInfo(
157     Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
158     const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
159     : Mapping(F) {
160   const size_t N = Mapping.size();
161   Block.resize(N);
162 
163   // Initialize every block so that it consumes itself
164   for (size_t I = 0; I < N; ++I) {
165     auto &B = Block[I];
166     B.Consumes.resize(N);
167     B.Kills.resize(N);
168     B.Consumes.set(I);
169     B.Changed = true;
170   }
171 
172   // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
173   // the code beyond coro.end is reachable during initial invocation of the
174   // coroutine.
175   for (auto *CE : CoroEnds) {
176     // Verify CoroEnd was normalized
177     assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() &&
178            CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB");
179 
180     getBlockData(CE->getParent()).End = true;
181   }
182 
183   // Mark all suspend blocks and indicate that they kill everything they
184   // consume. Note, that crossing coro.save also requires a spill, as any code
185   // between coro.save and coro.suspend may resume the coroutine and all of the
186   // state needs to be saved by that time.
187   auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
188     BasicBlock *SuspendBlock = BarrierInst->getParent();
189     auto &B = getBlockData(SuspendBlock);
190     B.Suspend = true;
191     B.Kills |= B.Consumes;
192   };
193   for (auto *CSI : CoroSuspends) {
194     // Verify CoroSuspend was normalized
195     assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() &&
196            CSI->getParent()->size() <= 2 &&
197            "CoroSuspend must be in its own BB");
198 
199     markSuspendBlock(CSI);
200     if (auto *Save = CSI->getCoroSave())
201       markSuspendBlock(Save);
202   }
203 
204   // It is considered to be faster to use RPO traversal for forward-edges
205   // dataflow analysis.
206   ReversePostOrderTraversal<Function *> RPOT(&F);
207   computeBlockData</*Initialize=*/true>(RPOT);
208   while (computeBlockData</*Initialize*/ false>(RPOT))
209     ;
210 
211   LLVM_DEBUG(dump());
212 }
213 
214 } // namespace llvm
215