xref: /llvm-project/llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp (revision 6292a808b3524d9ba6f4ce55bc5b9e547b088dd8)
12670565aSTyler Nowicki //===- MaterializationUtils.cpp - Builds and manipulates coroutine frame
22670565aSTyler Nowicki //-------------===//
32670565aSTyler Nowicki //
42670565aSTyler Nowicki // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
52670565aSTyler Nowicki // See https://llvm.org/LICENSE.txt for license information.
62670565aSTyler Nowicki // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
72670565aSTyler Nowicki //
82670565aSTyler Nowicki //===----------------------------------------------------------------------===//
// This file contains classes used to materialize instructions after suspend
// points.
102670565aSTyler Nowicki //===----------------------------------------------------------------------===//
112670565aSTyler Nowicki 
12e82fcda1STyler Nowicki #include "llvm/Transforms/Coroutines/MaterializationUtils.h"
13e82fcda1STyler Nowicki #include "CoroInternal.h"
142670565aSTyler Nowicki #include "llvm/ADT/PostOrderIterator.h"
152670565aSTyler Nowicki #include "llvm/IR/Dominators.h"
162670565aSTyler Nowicki #include "llvm/IR/InstIterator.h"
172670565aSTyler Nowicki #include "llvm/IR/Instruction.h"
184db57ab9STyler Nowicki #include "llvm/IR/ModuleSlotTracker.h"
19e82fcda1STyler Nowicki #include "llvm/Transforms/Coroutines/SpillUtils.h"
202670565aSTyler Nowicki #include <deque>
212670565aSTyler Nowicki 
222670565aSTyler Nowicki using namespace llvm;
232670565aSTyler Nowicki 
242670565aSTyler Nowicki using namespace coro;
252670565aSTyler Nowicki 
262670565aSTyler Nowicki // The "coro-suspend-crossing" flag is very noisy. There is another debug type,
272670565aSTyler Nowicki // "coro-frame", which results in leaner debug spew.
282670565aSTyler Nowicki #define DEBUG_TYPE "coro-suspend-crossing"
292670565aSTyler Nowicki 
302670565aSTyler Nowicki namespace {
312670565aSTyler Nowicki 
322670565aSTyler Nowicki // RematGraph is used to construct a DAG for rematerializable instructions
332670565aSTyler Nowicki // When the constructor is invoked with a candidate instruction (which is
342670565aSTyler Nowicki // materializable) it builds a DAG of materializable instructions from that
352670565aSTyler Nowicki // point.
362670565aSTyler Nowicki // Typically, for each instruction identified as re-materializable across a
372670565aSTyler Nowicki // suspend point, a RematGraph will be created.
382670565aSTyler Nowicki struct RematGraph {
392670565aSTyler Nowicki   // Each RematNode in the graph contains the edges to instructions providing
402670565aSTyler Nowicki   // operands in the current node.
412670565aSTyler Nowicki   struct RematNode {
422670565aSTyler Nowicki     Instruction *Node;
432670565aSTyler Nowicki     SmallVector<RematNode *> Operands;
442670565aSTyler Nowicki     RematNode() = default;
452670565aSTyler Nowicki     RematNode(Instruction *V) : Node(V) {}
462670565aSTyler Nowicki   };
472670565aSTyler Nowicki 
482670565aSTyler Nowicki   RematNode *EntryNode;
492670565aSTyler Nowicki   using RematNodeMap =
502670565aSTyler Nowicki       SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
512670565aSTyler Nowicki   RematNodeMap Remats;
522670565aSTyler Nowicki   const std::function<bool(Instruction &)> &MaterializableCallback;
532670565aSTyler Nowicki   SuspendCrossingInfo &Checker;
542670565aSTyler Nowicki 
552670565aSTyler Nowicki   RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
562670565aSTyler Nowicki              Instruction *I, SuspendCrossingInfo &Checker)
572670565aSTyler Nowicki       : MaterializableCallback(MaterializableCallback), Checker(Checker) {
582670565aSTyler Nowicki     std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
592670565aSTyler Nowicki     EntryNode = FirstNode.get();
602670565aSTyler Nowicki     std::deque<std::unique_ptr<RematNode>> WorkList;
612670565aSTyler Nowicki     addNode(std::move(FirstNode), WorkList, cast<User>(I));
622670565aSTyler Nowicki     while (WorkList.size()) {
632670565aSTyler Nowicki       std::unique_ptr<RematNode> N = std::move(WorkList.front());
642670565aSTyler Nowicki       WorkList.pop_front();
652670565aSTyler Nowicki       addNode(std::move(N), WorkList, cast<User>(I));
662670565aSTyler Nowicki     }
672670565aSTyler Nowicki   }
682670565aSTyler Nowicki 
692670565aSTyler Nowicki   void addNode(std::unique_ptr<RematNode> NUPtr,
702670565aSTyler Nowicki                std::deque<std::unique_ptr<RematNode>> &WorkList,
712670565aSTyler Nowicki                User *FirstUse) {
722670565aSTyler Nowicki     RematNode *N = NUPtr.get();
732670565aSTyler Nowicki     if (Remats.count(N->Node))
742670565aSTyler Nowicki       return;
752670565aSTyler Nowicki 
762670565aSTyler Nowicki     // We haven't see this node yet - add to the list
772670565aSTyler Nowicki     Remats[N->Node] = std::move(NUPtr);
782670565aSTyler Nowicki     for (auto &Def : N->Node->operands()) {
792670565aSTyler Nowicki       Instruction *D = dyn_cast<Instruction>(Def.get());
802670565aSTyler Nowicki       if (!D || !MaterializableCallback(*D) ||
812670565aSTyler Nowicki           !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
822670565aSTyler Nowicki         continue;
832670565aSTyler Nowicki 
846a5a795cSKazu Hirata       if (auto It = Remats.find(D); It != Remats.end()) {
852670565aSTyler Nowicki         // Already have this in the graph
866a5a795cSKazu Hirata         N->Operands.push_back(It->second.get());
872670565aSTyler Nowicki         continue;
882670565aSTyler Nowicki       }
892670565aSTyler Nowicki 
902670565aSTyler Nowicki       bool NoMatch = true;
912670565aSTyler Nowicki       for (auto &I : WorkList) {
922670565aSTyler Nowicki         if (I->Node == D) {
932670565aSTyler Nowicki           NoMatch = false;
942670565aSTyler Nowicki           N->Operands.push_back(I.get());
952670565aSTyler Nowicki           break;
962670565aSTyler Nowicki         }
972670565aSTyler Nowicki       }
982670565aSTyler Nowicki       if (NoMatch) {
992670565aSTyler Nowicki         // Create a new node
1002670565aSTyler Nowicki         std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
1012670565aSTyler Nowicki         N->Operands.push_back(ChildNode.get());
1022670565aSTyler Nowicki         WorkList.push_back(std::move(ChildNode));
1032670565aSTyler Nowicki       }
1042670565aSTyler Nowicki     }
1052670565aSTyler Nowicki   }
1062670565aSTyler Nowicki 
1072670565aSTyler Nowicki #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1084db57ab9STyler Nowicki   static void dumpBasicBlockLabel(const BasicBlock *BB,
1094db57ab9STyler Nowicki                                   ModuleSlotTracker &MST) {
1104db57ab9STyler Nowicki     if (BB->hasName()) {
1114db57ab9STyler Nowicki       dbgs() << BB->getName();
1124db57ab9STyler Nowicki       return;
1134db57ab9STyler Nowicki     }
1142670565aSTyler Nowicki 
1154db57ab9STyler Nowicki     dbgs() << MST.getLocalSlot(BB);
1162670565aSTyler Nowicki   }
1172670565aSTyler Nowicki 
1182670565aSTyler Nowicki   void dump() const {
1194db57ab9STyler Nowicki     BasicBlock *BB = EntryNode->Node->getParent();
1204db57ab9STyler Nowicki     Function *F = BB->getParent();
1214db57ab9STyler Nowicki 
1224db57ab9STyler Nowicki     ModuleSlotTracker MST(F->getParent());
1234db57ab9STyler Nowicki     MST.incorporateFunction(*F);
1244db57ab9STyler Nowicki 
1252670565aSTyler Nowicki     dbgs() << "Entry (";
1264db57ab9STyler Nowicki     dumpBasicBlockLabel(BB, MST);
1272670565aSTyler Nowicki     dbgs() << ") : " << *EntryNode->Node << "\n";
1282670565aSTyler Nowicki     for (auto &E : Remats) {
1292670565aSTyler Nowicki       dbgs() << *(E.first) << "\n";
1302670565aSTyler Nowicki       for (RematNode *U : E.second->Operands)
1312670565aSTyler Nowicki         dbgs() << "  " << *U->Node << "\n";
1322670565aSTyler Nowicki     }
1332670565aSTyler Nowicki   }
1342670565aSTyler Nowicki #endif
1352670565aSTyler Nowicki };
1362670565aSTyler Nowicki 
1372670565aSTyler Nowicki } // namespace
1382670565aSTyler Nowicki 
1392670565aSTyler Nowicki namespace llvm {
1402670565aSTyler Nowicki template <> struct GraphTraits<RematGraph *> {
1412670565aSTyler Nowicki   using NodeRef = RematGraph::RematNode *;
1422670565aSTyler Nowicki   using ChildIteratorType = RematGraph::RematNode **;
1432670565aSTyler Nowicki 
1442670565aSTyler Nowicki   static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
1452670565aSTyler Nowicki   static ChildIteratorType child_begin(NodeRef N) {
1462670565aSTyler Nowicki     return N->Operands.begin();
1472670565aSTyler Nowicki   }
1482670565aSTyler Nowicki   static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
1492670565aSTyler Nowicki };
1502670565aSTyler Nowicki 
1512670565aSTyler Nowicki } // end namespace llvm
1522670565aSTyler Nowicki 
1532670565aSTyler Nowicki // For each instruction identified as materializable across the suspend point,
1542670565aSTyler Nowicki // and its associated DAG of other rematerializable instructions,
1552670565aSTyler Nowicki // recreate the DAG of instructions after the suspend point.
1562670565aSTyler Nowicki static void rewriteMaterializableInstructions(
1572670565aSTyler Nowicki     const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
1582670565aSTyler Nowicki         &AllRemats) {
1592670565aSTyler Nowicki   // This has to be done in 2 phases
1602670565aSTyler Nowicki   // Do the remats and record the required defs to be replaced in the
1612670565aSTyler Nowicki   // original use instructions
1622670565aSTyler Nowicki   // Once all the remats are complete, replace the uses in the final
1632670565aSTyler Nowicki   // instructions with the new defs
1642670565aSTyler Nowicki   typedef struct {
1652670565aSTyler Nowicki     Instruction *Use;
1662670565aSTyler Nowicki     Instruction *Def;
1672670565aSTyler Nowicki     Instruction *Remat;
1682670565aSTyler Nowicki   } ProcessNode;
1692670565aSTyler Nowicki 
1702670565aSTyler Nowicki   SmallVector<ProcessNode> FinalInstructionsToProcess;
1712670565aSTyler Nowicki 
1722670565aSTyler Nowicki   for (const auto &E : AllRemats) {
1732670565aSTyler Nowicki     Instruction *Use = E.first;
1742670565aSTyler Nowicki     Instruction *CurrentMaterialization = nullptr;
1752670565aSTyler Nowicki     RematGraph *RG = E.second.get();
1762670565aSTyler Nowicki     ReversePostOrderTraversal<RematGraph *> RPOT(RG);
1772670565aSTyler Nowicki     SmallVector<Instruction *> InstructionsToProcess;
1782670565aSTyler Nowicki 
1792670565aSTyler Nowicki     // If the target use is actually a suspend instruction then we have to
1802670565aSTyler Nowicki     // insert the remats into the end of the predecessor (there should only be
1812670565aSTyler Nowicki     // one). This is so that suspend blocks always have the suspend instruction
1822670565aSTyler Nowicki     // as the first instruction.
183*6292a808SJeremy Morse     BasicBlock::iterator InsertPoint = Use->getParent()->getFirstInsertionPt();
1842670565aSTyler Nowicki     if (isa<AnyCoroSuspendInst>(Use)) {
1852670565aSTyler Nowicki       BasicBlock *SuspendPredecessorBlock =
1862670565aSTyler Nowicki           Use->getParent()->getSinglePredecessor();
1872670565aSTyler Nowicki       assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
188*6292a808SJeremy Morse       InsertPoint = SuspendPredecessorBlock->getTerminator()->getIterator();
1892670565aSTyler Nowicki     }
1902670565aSTyler Nowicki 
1912670565aSTyler Nowicki     // Note: skip the first instruction as this is the actual use that we're
1922670565aSTyler Nowicki     // rematerializing everything for.
1932670565aSTyler Nowicki     auto I = RPOT.begin();
1942670565aSTyler Nowicki     ++I;
1952670565aSTyler Nowicki     for (; I != RPOT.end(); ++I) {
1962670565aSTyler Nowicki       Instruction *D = (*I)->Node;
1972670565aSTyler Nowicki       CurrentMaterialization = D->clone();
1982670565aSTyler Nowicki       CurrentMaterialization->setName(D->getName());
1992670565aSTyler Nowicki       CurrentMaterialization->insertBefore(InsertPoint);
200*6292a808SJeremy Morse       InsertPoint = CurrentMaterialization->getIterator();
2012670565aSTyler Nowicki 
2022670565aSTyler Nowicki       // Replace all uses of Def in the instructions being added as part of this
2032670565aSTyler Nowicki       // rematerialization group
2042670565aSTyler Nowicki       for (auto &I : InstructionsToProcess)
2052670565aSTyler Nowicki         I->replaceUsesOfWith(D, CurrentMaterialization);
2062670565aSTyler Nowicki 
2072670565aSTyler Nowicki       // Don't replace the final use at this point as this can cause problems
2082670565aSTyler Nowicki       // for other materializations. Instead, for any final use that uses a
2092670565aSTyler Nowicki       // define that's being rematerialized, record the replace values
2102670565aSTyler Nowicki       for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
2112670565aSTyler Nowicki         if (Use->getOperand(i) == D) // Is this operand pointing to oldval?
2122670565aSTyler Nowicki           FinalInstructionsToProcess.push_back(
2132670565aSTyler Nowicki               {Use, D, CurrentMaterialization});
2142670565aSTyler Nowicki 
2152670565aSTyler Nowicki       InstructionsToProcess.push_back(CurrentMaterialization);
2162670565aSTyler Nowicki     }
2172670565aSTyler Nowicki   }
2182670565aSTyler Nowicki 
2192670565aSTyler Nowicki   // Finally, replace the uses with the defines that we've just rematerialized
2202670565aSTyler Nowicki   for (auto &R : FinalInstructionsToProcess) {
2212670565aSTyler Nowicki     if (auto *PN = dyn_cast<PHINode>(R.Use)) {
2222670565aSTyler Nowicki       assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
2232670565aSTyler Nowicki                                                 "values in the PHINode");
2242670565aSTyler Nowicki       PN->replaceAllUsesWith(R.Remat);
2252670565aSTyler Nowicki       PN->eraseFromParent();
2262670565aSTyler Nowicki       continue;
2272670565aSTyler Nowicki     }
2282670565aSTyler Nowicki     R.Use->replaceUsesOfWith(R.Def, R.Remat);
2292670565aSTyler Nowicki   }
2302670565aSTyler Nowicki }
2312670565aSTyler Nowicki 
2322670565aSTyler Nowicki /// Default materializable callback
2332670565aSTyler Nowicki // Check for instructions that we can recreate on resume as opposed to spill
2342670565aSTyler Nowicki // the result into a coroutine frame.
2352670565aSTyler Nowicki bool llvm::coro::defaultMaterializable(Instruction &V) {
2362670565aSTyler Nowicki   return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
2372670565aSTyler Nowicki           isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
2382670565aSTyler Nowicki }
2392670565aSTyler Nowicki 
2402670565aSTyler Nowicki bool llvm::coro::isTriviallyMaterializable(Instruction &V) {
2412670565aSTyler Nowicki   return defaultMaterializable(V);
2422670565aSTyler Nowicki }
2432670565aSTyler Nowicki 
2442670565aSTyler Nowicki #ifndef NDEBUG
2452670565aSTyler Nowicki static void dumpRemats(
2462670565aSTyler Nowicki     StringRef Title,
2472670565aSTyler Nowicki     const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
2482670565aSTyler Nowicki   dbgs() << "------------- " << Title << "--------------\n";
2492670565aSTyler Nowicki   for (const auto &E : RM) {
2502670565aSTyler Nowicki     E.second->dump();
2512670565aSTyler Nowicki     dbgs() << "--\n";
2522670565aSTyler Nowicki   }
2532670565aSTyler Nowicki }
2542670565aSTyler Nowicki #endif
2552670565aSTyler Nowicki 
2562670565aSTyler Nowicki void coro::doRematerializations(
2572670565aSTyler Nowicki     Function &F, SuspendCrossingInfo &Checker,
2582670565aSTyler Nowicki     std::function<bool(Instruction &)> IsMaterializable) {
2592670565aSTyler Nowicki   if (F.hasOptNone())
2602670565aSTyler Nowicki     return;
2612670565aSTyler Nowicki 
2622670565aSTyler Nowicki   coro::SpillInfo Spills;
2632670565aSTyler Nowicki 
2642670565aSTyler Nowicki   // See if there are materializable instructions across suspend points
2652670565aSTyler Nowicki   // We record these as the starting point to also identify materializable
2662670565aSTyler Nowicki   // defs of uses in these operations
2672670565aSTyler Nowicki   for (Instruction &I : instructions(F)) {
2682670565aSTyler Nowicki     if (!IsMaterializable(I))
2692670565aSTyler Nowicki       continue;
2702670565aSTyler Nowicki     for (User *U : I.users())
2712670565aSTyler Nowicki       if (Checker.isDefinitionAcrossSuspend(I, U))
2722670565aSTyler Nowicki         Spills[&I].push_back(cast<Instruction>(U));
2732670565aSTyler Nowicki   }
2742670565aSTyler Nowicki 
2752670565aSTyler Nowicki   // Process each of the identified rematerializable instructions
2762670565aSTyler Nowicki   // and add predecessor instructions that can also be rematerialized.
2772670565aSTyler Nowicki   // This is actually a graph of instructions since we could potentially
2782670565aSTyler Nowicki   // have multiple uses of a def in the set of predecessor instructions.
2792670565aSTyler Nowicki   // The approach here is to maintain a graph of instructions for each bottom
2802670565aSTyler Nowicki   // level instruction - where we have a unique set of instructions (nodes)
2812670565aSTyler Nowicki   // and edges between them. We then walk the graph in reverse post-dominator
2822670565aSTyler Nowicki   // order to insert them past the suspend point, but ensure that ordering is
2832670565aSTyler Nowicki   // correct. We also rely on CSE removing duplicate defs for remats of
2842670565aSTyler Nowicki   // different instructions with a def in common (rather than maintaining more
2852670565aSTyler Nowicki   // complex graphs for each suspend point)
2862670565aSTyler Nowicki 
2872670565aSTyler Nowicki   // We can do this by adding new nodes to the list for each suspend
2882670565aSTyler Nowicki   // point. Then using standard GraphTraits to give a reverse post-order
2892670565aSTyler Nowicki   // traversal when we insert the nodes after the suspend
2902670565aSTyler Nowicki   SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
2912670565aSTyler Nowicki   for (auto &E : Spills) {
2922670565aSTyler Nowicki     for (Instruction *U : E.second) {
2932670565aSTyler Nowicki       // Don't process a user twice (this can happen if the instruction uses
2942670565aSTyler Nowicki       // more than one rematerializable def)
2952670565aSTyler Nowicki       if (AllRemats.count(U))
2962670565aSTyler Nowicki         continue;
2972670565aSTyler Nowicki 
2982670565aSTyler Nowicki       // Constructor creates the whole RematGraph for the given Use
2992670565aSTyler Nowicki       auto RematUPtr =
3002670565aSTyler Nowicki           std::make_unique<RematGraph>(IsMaterializable, U, Checker);
3012670565aSTyler Nowicki 
3022670565aSTyler Nowicki       LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
3032670565aSTyler Nowicki                  ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
3042670565aSTyler Nowicki                  for (auto I = RPOT.begin(); I != RPOT.end();
3052670565aSTyler Nowicki                       ++I) { (*I)->Node->dump(); } dbgs()
3062670565aSTyler Nowicki                  << "\n";);
3072670565aSTyler Nowicki 
3082670565aSTyler Nowicki       AllRemats[U] = std::move(RematUPtr);
3092670565aSTyler Nowicki     }
3102670565aSTyler Nowicki   }
3112670565aSTyler Nowicki 
3122670565aSTyler Nowicki   // Rewrite materializable instructions to be materialized at the use
3132670565aSTyler Nowicki   // point.
3142670565aSTyler Nowicki   LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
3152670565aSTyler Nowicki   rewriteMaterializableInstructions(AllRemats);
3162670565aSTyler Nowicki }
317