xref: /llvm-project/llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp (revision 6292a808b3524d9ba6f4ce55bc5b9e547b088dd8)
1 //===- MaterializationUtils.cpp - Builds and manipulates coroutine frame
2 //-------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
// This file contains classes used to rematerialize instructions after suspend
// points.
10 //===----------------------------------------------------------------------===//
11 
12 #include "llvm/Transforms/Coroutines/MaterializationUtils.h"
13 #include "CoroInternal.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/IR/Dominators.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Instruction.h"
18 #include "llvm/IR/ModuleSlotTracker.h"
19 #include "llvm/Transforms/Coroutines/SpillUtils.h"
20 #include <deque>
21 
22 using namespace llvm;
23 
24 using namespace coro;
25 
26 // The "coro-suspend-crossing" flag is very noisy. There is another debug type,
27 // "coro-frame", which results in leaner debug spew.
28 #define DEBUG_TYPE "coro-suspend-crossing"
29 
30 namespace {
31 
32 // RematGraph is used to construct a DAG for rematerializable instructions
33 // When the constructor is invoked with a candidate instruction (which is
34 // materializable) it builds a DAG of materializable instructions from that
35 // point.
36 // Typically, for each instruction identified as re-materializable across a
37 // suspend point, a RematGraph will be created.
38 struct RematGraph {
39   // Each RematNode in the graph contains the edges to instructions providing
40   // operands in the current node.
41   struct RematNode {
42     Instruction *Node;
43     SmallVector<RematNode *> Operands;
44     RematNode() = default;
45     RematNode(Instruction *V) : Node(V) {}
46   };
47 
48   RematNode *EntryNode;
49   using RematNodeMap =
50       SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
51   RematNodeMap Remats;
52   const std::function<bool(Instruction &)> &MaterializableCallback;
53   SuspendCrossingInfo &Checker;
54 
55   RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
56              Instruction *I, SuspendCrossingInfo &Checker)
57       : MaterializableCallback(MaterializableCallback), Checker(Checker) {
58     std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
59     EntryNode = FirstNode.get();
60     std::deque<std::unique_ptr<RematNode>> WorkList;
61     addNode(std::move(FirstNode), WorkList, cast<User>(I));
62     while (WorkList.size()) {
63       std::unique_ptr<RematNode> N = std::move(WorkList.front());
64       WorkList.pop_front();
65       addNode(std::move(N), WorkList, cast<User>(I));
66     }
67   }
68 
69   void addNode(std::unique_ptr<RematNode> NUPtr,
70                std::deque<std::unique_ptr<RematNode>> &WorkList,
71                User *FirstUse) {
72     RematNode *N = NUPtr.get();
73     if (Remats.count(N->Node))
74       return;
75 
76     // We haven't see this node yet - add to the list
77     Remats[N->Node] = std::move(NUPtr);
78     for (auto &Def : N->Node->operands()) {
79       Instruction *D = dyn_cast<Instruction>(Def.get());
80       if (!D || !MaterializableCallback(*D) ||
81           !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
82         continue;
83 
84       if (auto It = Remats.find(D); It != Remats.end()) {
85         // Already have this in the graph
86         N->Operands.push_back(It->second.get());
87         continue;
88       }
89 
90       bool NoMatch = true;
91       for (auto &I : WorkList) {
92         if (I->Node == D) {
93           NoMatch = false;
94           N->Operands.push_back(I.get());
95           break;
96         }
97       }
98       if (NoMatch) {
99         // Create a new node
100         std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
101         N->Operands.push_back(ChildNode.get());
102         WorkList.push_back(std::move(ChildNode));
103       }
104     }
105   }
106 
107 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
108   static void dumpBasicBlockLabel(const BasicBlock *BB,
109                                   ModuleSlotTracker &MST) {
110     if (BB->hasName()) {
111       dbgs() << BB->getName();
112       return;
113     }
114 
115     dbgs() << MST.getLocalSlot(BB);
116   }
117 
118   void dump() const {
119     BasicBlock *BB = EntryNode->Node->getParent();
120     Function *F = BB->getParent();
121 
122     ModuleSlotTracker MST(F->getParent());
123     MST.incorporateFunction(*F);
124 
125     dbgs() << "Entry (";
126     dumpBasicBlockLabel(BB, MST);
127     dbgs() << ") : " << *EntryNode->Node << "\n";
128     for (auto &E : Remats) {
129       dbgs() << *(E.first) << "\n";
130       for (RematNode *U : E.second->Operands)
131         dbgs() << "  " << *U->Node << "\n";
132     }
133   }
134 #endif
135 };
136 
137 } // namespace
138 
139 namespace llvm {
140 template <> struct GraphTraits<RematGraph *> {
141   using NodeRef = RematGraph::RematNode *;
142   using ChildIteratorType = RematGraph::RematNode **;
143 
144   static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
145   static ChildIteratorType child_begin(NodeRef N) {
146     return N->Operands.begin();
147   }
148   static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
149 };
150 
151 } // end namespace llvm
152 
153 // For each instruction identified as materializable across the suspend point,
154 // and its associated DAG of other rematerializable instructions,
155 // recreate the DAG of instructions after the suspend point.
156 static void rewriteMaterializableInstructions(
157     const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
158         &AllRemats) {
159   // This has to be done in 2 phases
160   // Do the remats and record the required defs to be replaced in the
161   // original use instructions
162   // Once all the remats are complete, replace the uses in the final
163   // instructions with the new defs
164   typedef struct {
165     Instruction *Use;
166     Instruction *Def;
167     Instruction *Remat;
168   } ProcessNode;
169 
170   SmallVector<ProcessNode> FinalInstructionsToProcess;
171 
172   for (const auto &E : AllRemats) {
173     Instruction *Use = E.first;
174     Instruction *CurrentMaterialization = nullptr;
175     RematGraph *RG = E.second.get();
176     ReversePostOrderTraversal<RematGraph *> RPOT(RG);
177     SmallVector<Instruction *> InstructionsToProcess;
178 
179     // If the target use is actually a suspend instruction then we have to
180     // insert the remats into the end of the predecessor (there should only be
181     // one). This is so that suspend blocks always have the suspend instruction
182     // as the first instruction.
183     BasicBlock::iterator InsertPoint = Use->getParent()->getFirstInsertionPt();
184     if (isa<AnyCoroSuspendInst>(Use)) {
185       BasicBlock *SuspendPredecessorBlock =
186           Use->getParent()->getSinglePredecessor();
187       assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
188       InsertPoint = SuspendPredecessorBlock->getTerminator()->getIterator();
189     }
190 
191     // Note: skip the first instruction as this is the actual use that we're
192     // rematerializing everything for.
193     auto I = RPOT.begin();
194     ++I;
195     for (; I != RPOT.end(); ++I) {
196       Instruction *D = (*I)->Node;
197       CurrentMaterialization = D->clone();
198       CurrentMaterialization->setName(D->getName());
199       CurrentMaterialization->insertBefore(InsertPoint);
200       InsertPoint = CurrentMaterialization->getIterator();
201 
202       // Replace all uses of Def in the instructions being added as part of this
203       // rematerialization group
204       for (auto &I : InstructionsToProcess)
205         I->replaceUsesOfWith(D, CurrentMaterialization);
206 
207       // Don't replace the final use at this point as this can cause problems
208       // for other materializations. Instead, for any final use that uses a
209       // define that's being rematerialized, record the replace values
210       for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
211         if (Use->getOperand(i) == D) // Is this operand pointing to oldval?
212           FinalInstructionsToProcess.push_back(
213               {Use, D, CurrentMaterialization});
214 
215       InstructionsToProcess.push_back(CurrentMaterialization);
216     }
217   }
218 
219   // Finally, replace the uses with the defines that we've just rematerialized
220   for (auto &R : FinalInstructionsToProcess) {
221     if (auto *PN = dyn_cast<PHINode>(R.Use)) {
222       assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
223                                                 "values in the PHINode");
224       PN->replaceAllUsesWith(R.Remat);
225       PN->eraseFromParent();
226       continue;
227     }
228     R.Use->replaceUsesOfWith(R.Def, R.Remat);
229   }
230 }
231 
232 /// Default materializable callback
233 // Check for instructions that we can recreate on resume as opposed to spill
234 // the result into a coroutine frame.
235 bool llvm::coro::defaultMaterializable(Instruction &V) {
236   return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
237           isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
238 }
239 
240 bool llvm::coro::isTriviallyMaterializable(Instruction &V) {
241   return defaultMaterializable(V);
242 }
243 
244 #ifndef NDEBUG
245 static void dumpRemats(
246     StringRef Title,
247     const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
248   dbgs() << "------------- " << Title << "--------------\n";
249   for (const auto &E : RM) {
250     E.second->dump();
251     dbgs() << "--\n";
252   }
253 }
254 #endif
255 
256 void coro::doRematerializations(
257     Function &F, SuspendCrossingInfo &Checker,
258     std::function<bool(Instruction &)> IsMaterializable) {
259   if (F.hasOptNone())
260     return;
261 
262   coro::SpillInfo Spills;
263 
264   // See if there are materializable instructions across suspend points
265   // We record these as the starting point to also identify materializable
266   // defs of uses in these operations
267   for (Instruction &I : instructions(F)) {
268     if (!IsMaterializable(I))
269       continue;
270     for (User *U : I.users())
271       if (Checker.isDefinitionAcrossSuspend(I, U))
272         Spills[&I].push_back(cast<Instruction>(U));
273   }
274 
275   // Process each of the identified rematerializable instructions
276   // and add predecessor instructions that can also be rematerialized.
277   // This is actually a graph of instructions since we could potentially
278   // have multiple uses of a def in the set of predecessor instructions.
279   // The approach here is to maintain a graph of instructions for each bottom
280   // level instruction - where we have a unique set of instructions (nodes)
281   // and edges between them. We then walk the graph in reverse post-dominator
282   // order to insert them past the suspend point, but ensure that ordering is
283   // correct. We also rely on CSE removing duplicate defs for remats of
284   // different instructions with a def in common (rather than maintaining more
285   // complex graphs for each suspend point)
286 
287   // We can do this by adding new nodes to the list for each suspend
288   // point. Then using standard GraphTraits to give a reverse post-order
289   // traversal when we insert the nodes after the suspend
290   SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
291   for (auto &E : Spills) {
292     for (Instruction *U : E.second) {
293       // Don't process a user twice (this can happen if the instruction uses
294       // more than one rematerializable def)
295       if (AllRemats.count(U))
296         continue;
297 
298       // Constructor creates the whole RematGraph for the given Use
299       auto RematUPtr =
300           std::make_unique<RematGraph>(IsMaterializable, U, Checker);
301 
302       LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
303                  ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
304                  for (auto I = RPOT.begin(); I != RPOT.end();
305                       ++I) { (*I)->Node->dump(); } dbgs()
306                  << "\n";);
307 
308       AllRemats[U] = std::move(RematUPtr);
309     }
310   }
311 
312   // Rewrite materializable instructions to be materialized at the use
313   // point.
314   LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
315   rewriteMaterializableInstructions(AllRemats);
316 }
317