xref: /netbsd-src/external/apache2/llvm/dist/llvm/lib/Analysis/SyncDependenceAnalysis.cpp (revision 82d56013d7b633d116a93943de88e08335357a7c)
1 //===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm that returns for a divergent branch
10 // the set of basic blocks whose phi nodes become divergent due to divergent
11 // control. These are the blocks that are reachable by two disjoint paths from
12 // the branch or loop exits that have a reaching path that is disjoint from a
13 // path to the loop latch.
14 //
15 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16 // control-induced divergence in phi nodes.
17 //
18 // -- Summary --
19 // The SyncDependenceAnalysis lazily computes sync dependences [3].
20 // The analysis evaluates the disjoint path criterion [2] by a reduction
21 // to SSA construction. The SSA construction algorithm is implemented as
22 // a simple data-flow analysis [1].
23 //
24 // [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
25 // [2] "Efficiently Computing Static Single Assignment Form
26 //     and the Control Dependence Graph", TOPLAS '91,
27 //           Cytron, Ferrante, Rosen, Wegman and Zadeck
28 // [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
29 // [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
30 //
31 // -- Sync dependence --
32 // Sync dependence [4] characterizes the control flow aspect of the
33 // propagation of branch divergence. For example,
34 //
35 //   %cond = icmp slt i32 %tid, 10
36 //   br i1 %cond, label %then, label %else
37 // then:
38 //   br label %merge
39 // else:
40 //   br label %merge
41 // merge:
42 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
43 //
44 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
45 // because %tid is not on its use-def chains, %a is sync dependent on %tid
46 // because the branch "br i1 %cond" depends on %tid and affects which value %a
47 // is assigned to.
48 //
49 // -- Reduction to SSA construction --
50 // There are two disjoint paths from A to X, if a certain variant of SSA
51 // construction places a phi node in X under the following set-up scheme [2].
52 //
53 // This variant of SSA construction ignores incoming undef values.
54 // That is paths from the entry without a definition do not result in
55 // phi nodes.
56 //
57 //       entry
58 //     /      \
59 //    A        \
60 //  /   \       Y
61 // B     C     /
62 //  \   /  \  /
63 //    D     E
64 //     \   /
65 //       F
66 // Assume that A contains a divergent branch. We are interested
67 // in the set of all blocks where each block is reachable from A
68 // via two disjoint paths. This would be the set {D, F} in this
69 // case.
70 // To generally reduce this query to SSA construction we introduce
71 // a virtual variable x and assign to x different values in each
72 // successor block of A.
73 //           entry
74 //         /      \
75 //        A        \
76 //      /   \       Y
77 // x = 0   x = 1   /
78 //      \  /   \  /
79 //        D     E
80 //         \   /
81 //           F
82 // Our flavor of SSA construction for x will construct the following
83 //            entry
84 //          /      \
85 //         A        \
86 //       /   \       Y
87 // x0 = 0   x1 = 1  /
88 //       \   /   \ /
89 //      x2=phi    E
90 //         \     /
91 //          x3=phi
92 // The blocks D and F contain phi nodes and are thus each reachable
93 // by two disjoint paths from A.
94 //
95 // -- Remarks --
96 // In case of loop exits we need to check the disjoint path criterion for loops
97 // [2]. To this end, we check whether the definition of x differs between the
98 // loop exit and the loop header (_after_ SSA construction).
99 //
100 //===----------------------------------------------------------------------===//
101 #include "llvm/Analysis/SyncDependenceAnalysis.h"
102 #include "llvm/ADT/PostOrderIterator.h"
103 #include "llvm/ADT/SmallPtrSet.h"
104 #include "llvm/Analysis/PostDominators.h"
105 #include "llvm/IR/BasicBlock.h"
106 #include "llvm/IR/CFG.h"
107 #include "llvm/IR/Dominators.h"
108 #include "llvm/IR/Function.h"
109 
110 #include <functional>
111 #include <stack>
112 #include <unordered_set>
113 
114 #define DEBUG_TYPE "sync-dependence"
115 
116 // The SDA algorithm operates on a modified CFG - we modify the edges leaving
117 // loop headers as follows:
118 //
119 // * We remove all edges leaving all loop headers.
120 // * We add additional edges from the loop headers to their exit blocks.
121 //
122 // The modification is virtual, that is whenever we visit a loop header we
123 // pretend it had different successors.
124 namespace {
125 using namespace llvm;
126 
127 // Custom Post-Order Traversal
128 //
129 // We cannot use the vanilla (R)PO computation of LLVM because:
130 // * We (virtually) modify the CFG.
131 // * We want a loop-compact block enumeration, that is the numbers assigned by
132 //   the traversal to the blocks of a loop are an interval.
// Callback type of the traversal: it receives one basic block per call.
133 using POCB = std::function<void(const BasicBlock &)>;
// Set of blocks the traversal is already done with (passed around as the
// "Finalized" argument of the traversal helpers below).
134 using VisitedSet = std::set<const BasicBlock *>;
// Explicit work stack of blocks; the back of the vector is the top of stack.
135 using BlockStack = std::vector<const BasicBlock *>;
136 
137 // forward
138 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
139                           VisitedSet &Finalized);
140 
141 // for a nested region (top-level loop or nested loop)
computeStackPO(BlockStack & Stack,const LoopInfo & LI,Loop * Loop,POCB CallBack,VisitedSet & Finalized)142 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
143                            POCB CallBack, VisitedSet &Finalized) {
144   const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
145   while (!Stack.empty()) {
146     const auto *NextBB = Stack.back();
147 
148     auto *NestedLoop = LI.getLoopFor(NextBB);
149     bool IsNestedLoop = NestedLoop != Loop;
150 
151     // Treat the loop as a node
152     if (IsNestedLoop) {
153       SmallVector<BasicBlock *, 3> NestedExits;
154       NestedLoop->getUniqueExitBlocks(NestedExits);
155       bool PushedNodes = false;
156       for (const auto *NestedExitBB : NestedExits) {
157         if (NestedExitBB == LoopHeader)
158           continue;
159         if (Loop && !Loop->contains(NestedExitBB))
160           continue;
161         if (Finalized.count(NestedExitBB))
162           continue;
163         PushedNodes = true;
164         Stack.push_back(NestedExitBB);
165       }
166       if (!PushedNodes) {
167         // All loop exits finalized -> finish this node
168         Stack.pop_back();
169         computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
170       }
171       continue;
172     }
173 
174     // DAG-style
175     bool PushedNodes = false;
176     for (const auto *SuccBB : successors(NextBB)) {
177       if (SuccBB == LoopHeader)
178         continue;
179       if (Loop && !Loop->contains(SuccBB))
180         continue;
181       if (Finalized.count(SuccBB))
182         continue;
183       PushedNodes = true;
184       Stack.push_back(SuccBB);
185     }
186     if (!PushedNodes) {
187       // Never push nodes twice
188       Stack.pop_back();
189       if (!Finalized.insert(NextBB).second)
190         continue;
191       CallBack(*NextBB);
192     }
193   }
194 }
195 
computeTopLevelPO(Function & F,const LoopInfo & LI,POCB CallBack)196 static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
197   VisitedSet Finalized;
198   BlockStack Stack;
199   Stack.reserve(24); // FIXME made-up number
200   Stack.push_back(&F.getEntryBlock());
201   computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
202 }
203 
computeLoopPO(const LoopInfo & LI,Loop & Loop,POCB CallBack,VisitedSet & Finalized)204 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
205                           VisitedSet &Finalized) {
206   /// Call CallBack on all loop blocks.
207   std::vector<const BasicBlock *> Stack;
208   const auto *LoopHeader = Loop.getHeader();
209 
210   // Visit the header last
211   Finalized.insert(LoopHeader);
212   CallBack(*LoopHeader);
213 
214   // Initialize with immediate successors
215   for (const auto *BB : successors(LoopHeader)) {
216     if (!Loop.contains(BB))
217       continue;
218     if (BB == LoopHeader)
219       continue;
220     Stack.push_back(BB);
221   }
222 
223   // Compute PO inside region
224   computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
225 }
226 
227 } // namespace
228 
229 namespace llvm {
230 
231 ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
232 
SyncDependenceAnalysis(const DominatorTree & DT,const PostDominatorTree & PDT,const LoopInfo & LI)233 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
234                                                const PostDominatorTree &PDT,
235                                                const LoopInfo &LI)
236     : DT(DT), PDT(PDT), LI(LI) {
237   computeTopLevelPO(*DT.getRoot()->getParent(), LI,
238                     [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
239 }
240 
~SyncDependenceAnalysis()241 SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
242 
243 // divergence propagator for reducible CFGs
244 struct DivergencePropagator {
245   const ModifiedPO &LoopPOT;
246   const DominatorTree &DT;
247   const PostDominatorTree &PDT;
248   const LoopInfo &LI;
249   const BasicBlock &DivTermBlock;
250 
251   // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
252   //   block B
253   // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
254   // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
255   // from X or B is an immediate successor of X (initial value).
256   using BlockLabelVec = std::vector<const BasicBlock *>;
257   BlockLabelVec BlockLabels;
258   // divergent join and loop exit descriptor.
259   std::unique_ptr<ControlDivergenceDesc> DivDesc;
260 
DivergencePropagatorllvm::DivergencePropagator261   DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
262                        const PostDominatorTree &PDT, const LoopInfo &LI,
263                        const BasicBlock &DivTermBlock)
264       : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
265         BlockLabels(LoopPOT.size(), nullptr),
266         DivDesc(new ControlDivergenceDesc) {}
267 
printDefsllvm::DivergencePropagator268   void printDefs(raw_ostream &Out) {
269     Out << "Propagator::BlockLabels {\n";
270     for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
271       const auto *Label = BlockLabels[BlockIdx];
272       Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
273           << ") : ";
274       if (!Label) {
275         Out << "<null>\n";
276       } else {
277         Out << Label->getName() << "\n";
278       }
279     }
280     Out << "}\n";
281   }
282 
283   // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
284   // causes a divergent join.
computeJoinllvm::DivergencePropagator285   bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
286     auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
287 
288     // unset or same reaching label
289     const auto *OldLabel = BlockLabels[SuccIdx];
290     if (!OldLabel || (OldLabel == &PushedLabel)) {
291       BlockLabels[SuccIdx] = &PushedLabel;
292       return false;
293     }
294 
295     // Update the definition
296     BlockLabels[SuccIdx] = &SuccBlock;
297     return true;
298   }
299 
300   // visiting a virtual loop exit edge from the loop header --> temporal
301   // divergence on join
visitLoopExitEdgellvm::DivergencePropagator302   bool visitLoopExitEdge(const BasicBlock &ExitBlock,
303                          const BasicBlock &DefBlock, bool FromParentLoop) {
304     // Pushing from a non-parent loop cannot cause temporal divergence.
305     if (!FromParentLoop)
306       return visitEdge(ExitBlock, DefBlock);
307 
308     if (!computeJoin(ExitBlock, DefBlock))
309       return false;
310 
311     // Identified a divergent loop exit
312     DivDesc->LoopDivBlocks.insert(&ExitBlock);
313     LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
314                       << "\n");
315     return true;
316   }
317 
318   // process \p SuccBlock with reaching definition \p DefBlock
visitEdgellvm::DivergencePropagator319   bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
320     if (!computeJoin(SuccBlock, DefBlock))
321       return false;
322 
323     // Divergent, disjoint paths join.
324     DivDesc->JoinDivBlocks.insert(&SuccBlock);
325     LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
326     return true;
327   }
328 
computeJoinPointsllvm::DivergencePropagator329   std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
330     assert(DivDesc);
331 
332     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
333                       << "\n");
334 
335     const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
336 
337     // Early stopping criterion
338     int FloorIdx = LoopPOT.size() - 1;
339     const BasicBlock *FloorLabel = nullptr;
340 
341     // bootstrap with branch targets
342     int BlockIdx = 0;
343 
344     for (const auto *SuccBlock : successors(&DivTermBlock)) {
345       auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
346       BlockLabels[SuccIdx] = SuccBlock;
347 
348       // Find the successor with the highest index to start with
349       BlockIdx = std::max<int>(BlockIdx, SuccIdx);
350       FloorIdx = std::min<int>(FloorIdx, SuccIdx);
351 
352       // Identify immediate divergent loop exits
353       if (!DivBlockLoop)
354         continue;
355 
356       const auto *BlockLoop = LI.getLoopFor(SuccBlock);
357       if (BlockLoop && DivBlockLoop->contains(BlockLoop))
358         continue;
359       DivDesc->LoopDivBlocks.insert(SuccBlock);
360       LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
361                         << SuccBlock->getName() << "\n");
362     }
363 
364     // propagate definitions at the immediate successors of the node in RPO
365     for (; BlockIdx >= FloorIdx; --BlockIdx) {
366       LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
367 
368       // Any label available here
369       const auto *Label = BlockLabels[BlockIdx];
370       if (!Label)
371         continue;
372 
373       // Ok. Get the block
374       const auto *Block = LoopPOT.getBlockAt(BlockIdx);
375       LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
376 
377       auto *BlockLoop = LI.getLoopFor(Block);
378       bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
379       bool CausedJoin = false;
380       int LoweredFloorIdx = FloorIdx;
381       if (IsLoopHeader) {
382         // Disconnect from immediate successors and propagate directly to loop
383         // exits.
384         SmallVector<BasicBlock *, 4> BlockLoopExits;
385         BlockLoop->getExitBlocks(BlockLoopExits);
386 
387         bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
388         for (const auto *BlockLoopExit : BlockLoopExits) {
389           CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
390           LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
391                                           LoopPOT.getIndexOf(*BlockLoopExit));
392         }
393       } else {
394         // Acyclic successor case
395         for (const auto *SuccBlock : successors(Block)) {
396           CausedJoin |= visitEdge(*SuccBlock, *Label);
397           LoweredFloorIdx =
398               std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
399         }
400       }
401 
402       // Floor update
403       if (CausedJoin) {
404         // 1. Different labels pushed to successors
405         FloorIdx = LoweredFloorIdx;
406       } else if (FloorLabel != Label) {
407         // 2. No join caused BUT we pushed a label that is different than the
408         // last pushed label
409         FloorIdx = LoweredFloorIdx;
410         FloorLabel = Label;
411       }
412     }
413 
414     LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
415 
416     return std::move(DivDesc);
417   }
418 };
419 
420 #ifndef NDEBUG
printBlockSet(ConstBlockSet & Blocks,raw_ostream & Out)421 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
422   Out << "[";
423   ListSeparator LS;
424   for (const auto *BB : Blocks)
425     Out << LS << BB->getName();
426   Out << "]";
427 }
428 #endif
429 
430 const ControlDivergenceDesc &
getJoinBlocks(const Instruction & Term)431 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
432   // trivial case
433   if (Term.getNumSuccessors() <= 1) {
434     return EmptyDivergenceDesc;
435   }
436 
437   // already available in cache?
438   auto ItCached = CachedControlDivDescs.find(&Term);
439   if (ItCached != CachedControlDivDescs.end())
440     return *ItCached->second;
441 
442   // compute all join points
443   // Special handling of divergent loop exits is not needed for LCSSA
444   const auto &TermBlock = *Term.getParent();
445   DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
446   auto DivDesc = Propagator.computeJoinPoints();
447 
448   LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
449              dbgs() << "JoinDivBlocks: ";
450              printBlockSet(DivDesc->JoinDivBlocks, dbgs());
451              dbgs() << "\nLoopDivBlocks: ";
452              printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
453 
454   auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
455   assert(ItInserted.second);
456   return *ItInserted.first->second;
457 }
458 
459 } // namespace llvm
460