xref: /llvm-project/llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp (revision e82fcda1475b6708b7d314fd7a54e551306d5739)
1 //===- MaterializationUtils.cpp - Builds and manipulates coroutine frame
2 //-------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 // This file contains classes used to materialize insts after suspend points.
10 //===----------------------------------------------------------------------===//
11 
12 #include "llvm/Transforms/Coroutines/MaterializationUtils.h"
13 #include "CoroInternal.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/IR/Dominators.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Instruction.h"
18 #include "llvm/Transforms/Coroutines/SpillUtils.h"
19 #include <deque>
20 
21 using namespace llvm;
22 
23 using namespace coro;
24 
25 // The "coro-suspend-crossing" flag is very noisy. There is another debug type,
26 // "coro-frame", which results in leaner debug spew.
27 #define DEBUG_TYPE "coro-suspend-crossing"
28 
29 namespace {
30 
31 // RematGraph is used to construct a DAG for rematerializable instructions
32 // When the constructor is invoked with a candidate instruction (which is
33 // materializable) it builds a DAG of materializable instructions from that
34 // point.
35 // Typically, for each instruction identified as re-materializable across a
36 // suspend point, a RematGraph will be created.
37 struct RematGraph {
38   // Each RematNode in the graph contains the edges to instructions providing
39   // operands in the current node.
40   struct RematNode {
41     Instruction *Node;
42     SmallVector<RematNode *> Operands;
43     RematNode() = default;
44     RematNode(Instruction *V) : Node(V) {}
45   };
46 
47   RematNode *EntryNode;
48   using RematNodeMap =
49       SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
50   RematNodeMap Remats;
51   const std::function<bool(Instruction &)> &MaterializableCallback;
52   SuspendCrossingInfo &Checker;
53 
54   RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
55              Instruction *I, SuspendCrossingInfo &Checker)
56       : MaterializableCallback(MaterializableCallback), Checker(Checker) {
57     std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
58     EntryNode = FirstNode.get();
59     std::deque<std::unique_ptr<RematNode>> WorkList;
60     addNode(std::move(FirstNode), WorkList, cast<User>(I));
61     while (WorkList.size()) {
62       std::unique_ptr<RematNode> N = std::move(WorkList.front());
63       WorkList.pop_front();
64       addNode(std::move(N), WorkList, cast<User>(I));
65     }
66   }
67 
68   void addNode(std::unique_ptr<RematNode> NUPtr,
69                std::deque<std::unique_ptr<RematNode>> &WorkList,
70                User *FirstUse) {
71     RematNode *N = NUPtr.get();
72     if (Remats.count(N->Node))
73       return;
74 
75     // We haven't see this node yet - add to the list
76     Remats[N->Node] = std::move(NUPtr);
77     for (auto &Def : N->Node->operands()) {
78       Instruction *D = dyn_cast<Instruction>(Def.get());
79       if (!D || !MaterializableCallback(*D) ||
80           !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
81         continue;
82 
83       if (Remats.count(D)) {
84         // Already have this in the graph
85         N->Operands.push_back(Remats[D].get());
86         continue;
87       }
88 
89       bool NoMatch = true;
90       for (auto &I : WorkList) {
91         if (I->Node == D) {
92           NoMatch = false;
93           N->Operands.push_back(I.get());
94           break;
95         }
96       }
97       if (NoMatch) {
98         // Create a new node
99         std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
100         N->Operands.push_back(ChildNode.get());
101         WorkList.push_back(std::move(ChildNode));
102       }
103     }
104   }
105 
106 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
107   static std::string getBasicBlockLabel(const BasicBlock *BB) {
108     if (BB->hasName())
109       return BB->getName().str();
110 
111     std::string S;
112     raw_string_ostream OS(S);
113     BB->printAsOperand(OS, false);
114     return OS.str().substr(1);
115   }
116 
117   void dump() const {
118     dbgs() << "Entry (";
119     dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
120     dbgs() << ") : " << *EntryNode->Node << "\n";
121     for (auto &E : Remats) {
122       dbgs() << *(E.first) << "\n";
123       for (RematNode *U : E.second->Operands)
124         dbgs() << "  " << *U->Node << "\n";
125     }
126   }
127 #endif
128 };
129 
130 } // namespace
131 
132 namespace llvm {
133 template <> struct GraphTraits<RematGraph *> {
134   using NodeRef = RematGraph::RematNode *;
135   using ChildIteratorType = RematGraph::RematNode **;
136 
137   static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
138   static ChildIteratorType child_begin(NodeRef N) {
139     return N->Operands.begin();
140   }
141   static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
142 };
143 
144 } // end namespace llvm
145 
146 // For each instruction identified as materializable across the suspend point,
147 // and its associated DAG of other rematerializable instructions,
148 // recreate the DAG of instructions after the suspend point.
149 static void rewriteMaterializableInstructions(
150     const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
151         &AllRemats) {
152   // This has to be done in 2 phases
153   // Do the remats and record the required defs to be replaced in the
154   // original use instructions
155   // Once all the remats are complete, replace the uses in the final
156   // instructions with the new defs
157   typedef struct {
158     Instruction *Use;
159     Instruction *Def;
160     Instruction *Remat;
161   } ProcessNode;
162 
163   SmallVector<ProcessNode> FinalInstructionsToProcess;
164 
165   for (const auto &E : AllRemats) {
166     Instruction *Use = E.first;
167     Instruction *CurrentMaterialization = nullptr;
168     RematGraph *RG = E.second.get();
169     ReversePostOrderTraversal<RematGraph *> RPOT(RG);
170     SmallVector<Instruction *> InstructionsToProcess;
171 
172     // If the target use is actually a suspend instruction then we have to
173     // insert the remats into the end of the predecessor (there should only be
174     // one). This is so that suspend blocks always have the suspend instruction
175     // as the first instruction.
176     auto InsertPoint = &*Use->getParent()->getFirstInsertionPt();
177     if (isa<AnyCoroSuspendInst>(Use)) {
178       BasicBlock *SuspendPredecessorBlock =
179           Use->getParent()->getSinglePredecessor();
180       assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
181       InsertPoint = SuspendPredecessorBlock->getTerminator();
182     }
183 
184     // Note: skip the first instruction as this is the actual use that we're
185     // rematerializing everything for.
186     auto I = RPOT.begin();
187     ++I;
188     for (; I != RPOT.end(); ++I) {
189       Instruction *D = (*I)->Node;
190       CurrentMaterialization = D->clone();
191       CurrentMaterialization->setName(D->getName());
192       CurrentMaterialization->insertBefore(InsertPoint);
193       InsertPoint = CurrentMaterialization;
194 
195       // Replace all uses of Def in the instructions being added as part of this
196       // rematerialization group
197       for (auto &I : InstructionsToProcess)
198         I->replaceUsesOfWith(D, CurrentMaterialization);
199 
200       // Don't replace the final use at this point as this can cause problems
201       // for other materializations. Instead, for any final use that uses a
202       // define that's being rematerialized, record the replace values
203       for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
204         if (Use->getOperand(i) == D) // Is this operand pointing to oldval?
205           FinalInstructionsToProcess.push_back(
206               {Use, D, CurrentMaterialization});
207 
208       InstructionsToProcess.push_back(CurrentMaterialization);
209     }
210   }
211 
212   // Finally, replace the uses with the defines that we've just rematerialized
213   for (auto &R : FinalInstructionsToProcess) {
214     if (auto *PN = dyn_cast<PHINode>(R.Use)) {
215       assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
216                                                 "values in the PHINode");
217       PN->replaceAllUsesWith(R.Remat);
218       PN->eraseFromParent();
219       continue;
220     }
221     R.Use->replaceUsesOfWith(R.Def, R.Remat);
222   }
223 }
224 
225 /// Default materializable callback
226 // Check for instructions that we can recreate on resume as opposed to spill
227 // the result into a coroutine frame.
228 bool llvm::coro::defaultMaterializable(Instruction &V) {
229   return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
230           isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
231 }
232 
233 bool llvm::coro::isTriviallyMaterializable(Instruction &V) {
234   return defaultMaterializable(V);
235 }
236 
237 #ifndef NDEBUG
238 static void dumpRemats(
239     StringRef Title,
240     const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
241   dbgs() << "------------- " << Title << "--------------\n";
242   for (const auto &E : RM) {
243     E.second->dump();
244     dbgs() << "--\n";
245   }
246 }
247 #endif
248 
249 void coro::doRematerializations(
250     Function &F, SuspendCrossingInfo &Checker,
251     std::function<bool(Instruction &)> IsMaterializable) {
252   if (F.hasOptNone())
253     return;
254 
255   coro::SpillInfo Spills;
256 
257   // See if there are materializable instructions across suspend points
258   // We record these as the starting point to also identify materializable
259   // defs of uses in these operations
260   for (Instruction &I : instructions(F)) {
261     if (!IsMaterializable(I))
262       continue;
263     for (User *U : I.users())
264       if (Checker.isDefinitionAcrossSuspend(I, U))
265         Spills[&I].push_back(cast<Instruction>(U));
266   }
267 
268   // Process each of the identified rematerializable instructions
269   // and add predecessor instructions that can also be rematerialized.
270   // This is actually a graph of instructions since we could potentially
271   // have multiple uses of a def in the set of predecessor instructions.
272   // The approach here is to maintain a graph of instructions for each bottom
273   // level instruction - where we have a unique set of instructions (nodes)
274   // and edges between them. We then walk the graph in reverse post-dominator
275   // order to insert them past the suspend point, but ensure that ordering is
276   // correct. We also rely on CSE removing duplicate defs for remats of
277   // different instructions with a def in common (rather than maintaining more
278   // complex graphs for each suspend point)
279 
280   // We can do this by adding new nodes to the list for each suspend
281   // point. Then using standard GraphTraits to give a reverse post-order
282   // traversal when we insert the nodes after the suspend
283   SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
284   for (auto &E : Spills) {
285     for (Instruction *U : E.second) {
286       // Don't process a user twice (this can happen if the instruction uses
287       // more than one rematerializable def)
288       if (AllRemats.count(U))
289         continue;
290 
291       // Constructor creates the whole RematGraph for the given Use
292       auto RematUPtr =
293           std::make_unique<RematGraph>(IsMaterializable, U, Checker);
294 
295       LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
296                  ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
297                  for (auto I = RPOT.begin(); I != RPOT.end();
298                       ++I) { (*I)->Node->dump(); } dbgs()
299                  << "\n";);
300 
301       AllRemats[U] = std::move(RematUPtr);
302     }
303   }
304 
305   // Rewrite materializable instructions to be materialized at the use
306   // point.
307   LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
308   rewriteMaterializableInstructions(AllRemats);
309 }
310