1 //===- MaterializationUtils.cpp - Builds and manipulates coroutine frame 2 //-------------===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 // This file contains classes used to materialize insts after suspends points. 10 //===----------------------------------------------------------------------===// 11 12 #include "llvm/Transforms/Coroutines/MaterializationUtils.h" 13 #include "CoroInternal.h" 14 #include "llvm/ADT/PostOrderIterator.h" 15 #include "llvm/IR/Dominators.h" 16 #include "llvm/IR/InstIterator.h" 17 #include "llvm/IR/Instruction.h" 18 #include "llvm/IR/ModuleSlotTracker.h" 19 #include "llvm/Transforms/Coroutines/SpillUtils.h" 20 #include <deque> 21 22 using namespace llvm; 23 24 using namespace coro; 25 26 // The "coro-suspend-crossing" flag is very noisy. There is another debug type, 27 // "coro-frame", which results in leaner debug spew. 28 #define DEBUG_TYPE "coro-suspend-crossing" 29 30 namespace { 31 32 // RematGraph is used to construct a DAG for rematerializable instructions 33 // When the constructor is invoked with a candidate instruction (which is 34 // materializable) it builds a DAG of materializable instructions from that 35 // point. 36 // Typically, for each instruction identified as re-materializable across a 37 // suspend point, a RematGraph will be created. 38 struct RematGraph { 39 // Each RematNode in the graph contains the edges to instructions providing 40 // operands in the current node. 41 struct RematNode { 42 Instruction *Node; 43 SmallVector<RematNode *> Operands; 44 RematNode() = default; 45 RematNode(Instruction *V) : Node(V) {} 46 }; 47 48 RematNode *EntryNode; 49 using RematNodeMap = 50 SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>; 51 RematNodeMap Remats; 52 const std::function<bool(Instruction &)> &MaterializableCallback; 53 SuspendCrossingInfo &Checker; 54 55 RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback, 56 Instruction *I, SuspendCrossingInfo &Checker) 57 : MaterializableCallback(MaterializableCallback), Checker(Checker) { 58 std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I); 59 EntryNode = FirstNode.get(); 60 std::deque<std::unique_ptr<RematNode>> WorkList; 61 addNode(std::move(FirstNode), WorkList, cast<User>(I)); 62 while (WorkList.size()) { 63 std::unique_ptr<RematNode> N = std::move(WorkList.front()); 64 WorkList.pop_front(); 65 addNode(std::move(N), WorkList, cast<User>(I)); 66 } 67 } 68 69 void addNode(std::unique_ptr<RematNode> NUPtr, 70 std::deque<std::unique_ptr<RematNode>> &WorkList, 71 User *FirstUse) { 72 RematNode *N = NUPtr.get(); 73 if (Remats.count(N->Node)) 74 return; 75 76 // We haven't see this node yet - add to the list 77 Remats[N->Node] = std::move(NUPtr); 78 for (auto &Def : N->Node->operands()) { 79 Instruction *D = dyn_cast<Instruction>(Def.get()); 80 if (!D || !MaterializableCallback(*D) || 81 !Checker.isDefinitionAcrossSuspend(*D, FirstUse)) 82 continue; 83 84 if (auto It = Remats.find(D); It != Remats.end()) { 85 // Already have this in the graph 86 N->Operands.push_back(It->second.get()); 87 continue; 88 } 89 90 bool NoMatch = true; 91 for (auto &I : WorkList) { 92 if (I->Node == D) { 93 NoMatch = false; 94 N->Operands.push_back(I.get()); 95 break; 96 } 97 } 98 if (NoMatch) { 99 // Create a new node 100 std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D); 101 N->Operands.push_back(ChildNode.get()); 102 WorkList.push_back(std::move(ChildNode)); 103 } 104 } 105 } 106 107 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 108 static void dumpBasicBlockLabel(const BasicBlock *BB, 109 ModuleSlotTracker &MST) { 110 if (BB->hasName()) { 111 dbgs() << BB->getName(); 112 return; 113 } 114 115 dbgs() << MST.getLocalSlot(BB); 116 } 117 118 void dump() const { 119 BasicBlock *BB = EntryNode->Node->getParent(); 120 Function *F = BB->getParent(); 121 122 ModuleSlotTracker MST(F->getParent()); 123 MST.incorporateFunction(*F); 124 125 dbgs() << "Entry ("; 126 dumpBasicBlockLabel(BB, MST); 127 dbgs() << ") : " << *EntryNode->Node << "\n"; 128 for (auto &E : Remats) { 129 dbgs() << *(E.first) << "\n"; 130 for (RematNode *U : E.second->Operands) 131 dbgs() << " " << *U->Node << "\n"; 132 } 133 } 134 #endif 135 }; 136 137 } // namespace 138 139 namespace llvm { 140 template <> struct GraphTraits<RematGraph *> { 141 using NodeRef = RematGraph::RematNode *; 142 using ChildIteratorType = RematGraph::RematNode **; 143 144 static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; } 145 static ChildIteratorType child_begin(NodeRef N) { 146 return N->Operands.begin(); 147 } 148 static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); } 149 }; 150 151 } // end namespace llvm 152 153 // For each instruction identified as materializable across the suspend point, 154 // and its associated DAG of other rematerializable instructions, 155 // recreate the DAG of instructions after the suspend point. 156 static void rewriteMaterializableInstructions( 157 const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> 158 &AllRemats) { 159 // This has to be done in 2 phases 160 // Do the remats and record the required defs to be replaced in the 161 // original use instructions 162 // Once all the remats are complete, replace the uses in the final 163 // instructions with the new defs 164 typedef struct { 165 Instruction *Use; 166 Instruction *Def; 167 Instruction *Remat; 168 } ProcessNode; 169 170 SmallVector<ProcessNode> FinalInstructionsToProcess; 171 172 for (const auto &E : AllRemats) { 173 Instruction *Use = E.first; 174 Instruction *CurrentMaterialization = nullptr; 175 RematGraph *RG = E.second.get(); 176 ReversePostOrderTraversal<RematGraph *> RPOT(RG); 177 SmallVector<Instruction *> InstructionsToProcess; 178 179 // If the target use is actually a suspend instruction then we have to 180 // insert the remats into the end of the predecessor (there should only be 181 // one). This is so that suspend blocks always have the suspend instruction 182 // as the first instruction. 183 BasicBlock::iterator InsertPoint = Use->getParent()->getFirstInsertionPt(); 184 if (isa<AnyCoroSuspendInst>(Use)) { 185 BasicBlock *SuspendPredecessorBlock = 186 Use->getParent()->getSinglePredecessor(); 187 assert(SuspendPredecessorBlock && "malformed coro suspend instruction"); 188 InsertPoint = SuspendPredecessorBlock->getTerminator()->getIterator(); 189 } 190 191 // Note: skip the first instruction as this is the actual use that we're 192 // rematerializing everything for. 193 auto I = RPOT.begin(); 194 ++I; 195 for (; I != RPOT.end(); ++I) { 196 Instruction *D = (*I)->Node; 197 CurrentMaterialization = D->clone(); 198 CurrentMaterialization->setName(D->getName()); 199 CurrentMaterialization->insertBefore(InsertPoint); 200 InsertPoint = CurrentMaterialization->getIterator(); 201 202 // Replace all uses of Def in the instructions being added as part of this 203 // rematerialization group 204 for (auto &I : InstructionsToProcess) 205 I->replaceUsesOfWith(D, CurrentMaterialization); 206 207 // Don't replace the final use at this point as this can cause problems 208 // for other materializations. Instead, for any final use that uses a 209 // define that's being rematerialized, record the replace values 210 for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i) 211 if (Use->getOperand(i) == D) // Is this operand pointing to oldval? 212 FinalInstructionsToProcess.push_back( 213 {Use, D, CurrentMaterialization}); 214 215 InstructionsToProcess.push_back(CurrentMaterialization); 216 } 217 } 218 219 // Finally, replace the uses with the defines that we've just rematerialized 220 for (auto &R : FinalInstructionsToProcess) { 221 if (auto *PN = dyn_cast<PHINode>(R.Use)) { 222 assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming " 223 "values in the PHINode"); 224 PN->replaceAllUsesWith(R.Remat); 225 PN->eraseFromParent(); 226 continue; 227 } 228 R.Use->replaceUsesOfWith(R.Def, R.Remat); 229 } 230 } 231 232 /// Default materializable callback 233 // Check for instructions that we can recreate on resume as opposed to spill 234 // the result into a coroutine frame. 235 bool llvm::coro::defaultMaterializable(Instruction &V) { 236 return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) || 237 isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V)); 238 } 239 240 bool llvm::coro::isTriviallyMaterializable(Instruction &V) { 241 return defaultMaterializable(V); 242 } 243 244 #ifndef NDEBUG 245 static void dumpRemats( 246 StringRef Title, 247 const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) { 248 dbgs() << "------------- " << Title << "--------------\n"; 249 for (const auto &E : RM) { 250 E.second->dump(); 251 dbgs() << "--\n"; 252 } 253 } 254 #endif 255 256 void coro::doRematerializations( 257 Function &F, SuspendCrossingInfo &Checker, 258 std::function<bool(Instruction &)> IsMaterializable) { 259 if (F.hasOptNone()) 260 return; 261 262 coro::SpillInfo Spills; 263 264 // See if there are materializable instructions across suspend points 265 // We record these as the starting point to also identify materializable 266 // defs of uses in these operations 267 for (Instruction &I : instructions(F)) { 268 if (!IsMaterializable(I)) 269 continue; 270 for (User *U : I.users()) 271 if (Checker.isDefinitionAcrossSuspend(I, U)) 272 Spills[&I].push_back(cast<Instruction>(U)); 273 } 274 275 // Process each of the identified rematerializable instructions 276 // and add predecessor instructions that can also be rematerialized. 277 // This is actually a graph of instructions since we could potentially 278 // have multiple uses of a def in the set of predecessor instructions. 279 // The approach here is to maintain a graph of instructions for each bottom 280 // level instruction - where we have a unique set of instructions (nodes) 281 // and edges between them. We then walk the graph in reverse post-dominator 282 // order to insert them past the suspend point, but ensure that ordering is 283 // correct. We also rely on CSE removing duplicate defs for remats of 284 // different instructions with a def in common (rather than maintaining more 285 // complex graphs for each suspend point) 286 287 // We can do this by adding new nodes to the list for each suspend 288 // point. Then using standard GraphTraits to give a reverse post-order 289 // traversal when we insert the nodes after the suspend 290 SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats; 291 for (auto &E : Spills) { 292 for (Instruction *U : E.second) { 293 // Don't process a user twice (this can happen if the instruction uses 294 // more than one rematerializable def) 295 if (AllRemats.count(U)) 296 continue; 297 298 // Constructor creates the whole RematGraph for the given Use 299 auto RematUPtr = 300 std::make_unique<RematGraph>(IsMaterializable, U, Checker); 301 302 LLVM_DEBUG(dbgs() << "***** Next remat group *****\n"; 303 ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get()); 304 for (auto I = RPOT.begin(); I != RPOT.end(); 305 ++I) { (*I)->Node->dump(); } dbgs() 306 << "\n";); 307 308 AllRemats[U] = std::move(RematUPtr); 309 } 310 } 311 312 // Rewrite materializable instructions to be materialized at the use 313 // point. 314 LLVM_DEBUG(dumpRemats("Materializations", AllRemats)); 315 rewriteMaterializableInstructions(AllRemats); 316 } 317