xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (revision 304a99091c84f303ff5037dc6bf5455e4cfde7a1)
11ed65febSNathan Gauër //===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
21ed65febSNathan Gauër //
31ed65febSNathan Gauër // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41ed65febSNathan Gauër // See https://llvm.org/LICENSE.txt for license information.
51ed65febSNathan Gauër // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61ed65febSNathan Gauër //
71ed65febSNathan Gauër //===----------------------------------------------------------------------===//
81ed65febSNathan Gauër //
91ed65febSNathan Gauër //===----------------------------------------------------------------------===//
101ed65febSNathan Gauër 
111ed65febSNathan Gauër #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
121ed65febSNathan Gauër #include "SPIRV.h"
1310b1caf6Sjoaosaffran #include "SPIRVStructurizerWrapper.h"
141ed65febSNathan Gauër #include "SPIRVSubtarget.h"
151ed65febSNathan Gauër #include "SPIRVTargetMachine.h"
161ed65febSNathan Gauër #include "SPIRVUtils.h"
171ed65febSNathan Gauër #include "llvm/ADT/DenseMap.h"
181ed65febSNathan Gauër #include "llvm/ADT/SmallPtrSet.h"
191ed65febSNathan Gauër #include "llvm/Analysis/LoopInfo.h"
201ed65febSNathan Gauër #include "llvm/CodeGen/IntrinsicLowering.h"
211ed65febSNathan Gauër #include "llvm/IR/CFG.h"
221ed65febSNathan Gauër #include "llvm/IR/Dominators.h"
231ed65febSNathan Gauër #include "llvm/IR/IRBuilder.h"
241ed65febSNathan Gauër #include "llvm/IR/IntrinsicInst.h"
251ed65febSNathan Gauër #include "llvm/IR/Intrinsics.h"
261ed65febSNathan Gauër #include "llvm/IR/IntrinsicsSPIRV.h"
27380bb51bSjoaosaffran #include "llvm/IR/LegacyPassManager.h"
281ed65febSNathan Gauër #include "llvm/InitializePasses.h"
29380bb51bSjoaosaffran #include "llvm/PassRegistry.h"
30380bb51bSjoaosaffran #include "llvm/Transforms/Utils.h"
311ed65febSNathan Gauër #include "llvm/Transforms/Utils/Cloning.h"
321ed65febSNathan Gauër #include "llvm/Transforms/Utils/LoopSimplify.h"
331ed65febSNathan Gauër #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
341ed65febSNathan Gauër #include <queue>
351ed65febSNathan Gauër #include <stack>
361ed65febSNathan Gauër #include <unordered_set>
371ed65febSNathan Gauër 
381ed65febSNathan Gauër using namespace llvm;
391ed65febSNathan Gauër using namespace SPIRV;
401ed65febSNathan Gauër 
411ed65febSNathan Gauër namespace llvm {
421ed65febSNathan Gauër 
431ed65febSNathan Gauër void initializeSPIRVStructurizerPass(PassRegistry &);
441ed65febSNathan Gauër 
451ed65febSNathan Gauër namespace {
461ed65febSNathan Gauër 
471ed65febSNathan Gauër using BlockSet = std::unordered_set<BasicBlock *>;
481ed65febSNathan Gauër using Edge = std::pair<BasicBlock *, BasicBlock *>;
491ed65febSNathan Gauër 
501ed65febSNathan Gauër // Helper function to do a partial order visit from the block |Start|, calling
511ed65febSNathan Gauër // |Op| on each visited node.
521ed65febSNathan Gauër void partialOrderVisit(BasicBlock &Start,
531ed65febSNathan Gauër                        std::function<bool(BasicBlock *)> Op) {
541ed65febSNathan Gauër   PartialOrderingVisitor V(*Start.getParent());
551ed65febSNathan Gauër   V.partialOrderVisit(Start, Op);
561ed65febSNathan Gauër }
571ed65febSNathan Gauër 
581ed65febSNathan Gauër // Returns the exact convergence region in the tree defined by `Node` for which
591ed65febSNathan Gauër // `BB` is the header, nullptr otherwise.
601ed65febSNathan Gauër const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
611ed65febSNathan Gauër                                             BasicBlock *BB) {
621ed65febSNathan Gauër   if (Node->Entry == BB)
631ed65febSNathan Gauër     return Node;
641ed65febSNathan Gauër 
651ed65febSNathan Gauër   for (auto *Child : Node->Children) {
661ed65febSNathan Gauër     const auto *CR = getRegionForHeader(Child, BB);
671ed65febSNathan Gauër     if (CR != nullptr)
681ed65febSNathan Gauër       return CR;
691ed65febSNathan Gauër   }
701ed65febSNathan Gauër   return nullptr;
711ed65febSNathan Gauër }
721ed65febSNathan Gauër 
731ed65febSNathan Gauër // Returns the single BasicBlock exiting the convergence region `CR`,
741ed65febSNathan Gauër // nullptr if no such exit exists.
751ed65febSNathan Gauër BasicBlock *getExitFor(const ConvergenceRegion *CR) {
761ed65febSNathan Gauër   std::unordered_set<BasicBlock *> ExitTargets;
771ed65febSNathan Gauër   for (BasicBlock *Exit : CR->Exits) {
781ed65febSNathan Gauër     for (BasicBlock *Successor : successors(Exit)) {
791ed65febSNathan Gauër       if (CR->Blocks.count(Successor) == 0)
801ed65febSNathan Gauër         ExitTargets.insert(Successor);
811ed65febSNathan Gauër     }
821ed65febSNathan Gauër   }
831ed65febSNathan Gauër 
841ed65febSNathan Gauër   assert(ExitTargets.size() <= 1);
851ed65febSNathan Gauër   if (ExitTargets.size() == 0)
861ed65febSNathan Gauër     return nullptr;
871ed65febSNathan Gauër 
881ed65febSNathan Gauër   return *ExitTargets.begin();
891ed65febSNathan Gauër }
901ed65febSNathan Gauër 
911ed65febSNathan Gauër // Returns the merge block designated by I if I is a merge instruction, nullptr
921ed65febSNathan Gauër // otherwise.
931ed65febSNathan Gauër BasicBlock *getDesignatedMergeBlock(Instruction *I) {
94cba70550SNathan Gauër   IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I);
951ed65febSNathan Gauër   if (II == nullptr)
961ed65febSNathan Gauër     return nullptr;
971ed65febSNathan Gauër 
981ed65febSNathan Gauër   if (II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
991ed65febSNathan Gauër       II->getIntrinsicID() != Intrinsic::spv_selection_merge)
1001ed65febSNathan Gauër     return nullptr;
1011ed65febSNathan Gauër 
1021ed65febSNathan Gauër   BlockAddress *BA = cast<BlockAddress>(II->getOperand(0));
1031ed65febSNathan Gauër   return BA->getBasicBlock();
1041ed65febSNathan Gauër }
1051ed65febSNathan Gauër 
1061ed65febSNathan Gauër // Returns the continue block designated by I if I is an OpLoopMerge, nullptr
1071ed65febSNathan Gauër // otherwise.
1081ed65febSNathan Gauër BasicBlock *getDesignatedContinueBlock(Instruction *I) {
109cba70550SNathan Gauër   IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I);
1101ed65febSNathan Gauër   if (II == nullptr)
1111ed65febSNathan Gauër     return nullptr;
1121ed65febSNathan Gauër 
1131ed65febSNathan Gauër   if (II->getIntrinsicID() != Intrinsic::spv_loop_merge)
1141ed65febSNathan Gauër     return nullptr;
1151ed65febSNathan Gauër 
1161ed65febSNathan Gauër   BlockAddress *BA = cast<BlockAddress>(II->getOperand(1));
1171ed65febSNathan Gauër   return BA->getBasicBlock();
1181ed65febSNathan Gauër }
1191ed65febSNathan Gauër 
1201ed65febSNathan Gauër // Returns true if Header has one merge instruction which designated Merge as
1211ed65febSNathan Gauër // merge block.
1221ed65febSNathan Gauër bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
1231ed65febSNathan Gauër   for (auto &I : Header) {
1241ed65febSNathan Gauër     BasicBlock *MB = getDesignatedMergeBlock(&I);
1251ed65febSNathan Gauër     if (MB == &Merge)
1261ed65febSNathan Gauër       return true;
1271ed65febSNathan Gauër   }
1281ed65febSNathan Gauër   return false;
1291ed65febSNathan Gauër }
1301ed65febSNathan Gauër 
1311ed65febSNathan Gauër // Returns true if the BB has one OpLoopMerge instruction.
1321ed65febSNathan Gauër bool hasLoopMergeInstruction(BasicBlock &BB) {
1331ed65febSNathan Gauër   for (auto &I : BB)
1341ed65febSNathan Gauër     if (getDesignatedContinueBlock(&I))
1351ed65febSNathan Gauër       return true;
1361ed65febSNathan Gauër   return false;
1371ed65febSNathan Gauër }
1381ed65febSNathan Gauër 
1391ed65febSNathan Gauër // Returns true is I is an OpSelectionMerge or OpLoopMerge instruction, false
1401ed65febSNathan Gauër // otherwise.
1411ed65febSNathan Gauër bool isMergeInstruction(Instruction *I) {
1421ed65febSNathan Gauër   return getDesignatedMergeBlock(I) != nullptr;
1431ed65febSNathan Gauër }
1441ed65febSNathan Gauër 
1451ed65febSNathan Gauër // Returns all blocks in F having at least one OpLoopMerge or OpSelectionMerge
1461ed65febSNathan Gauër // instruction.
1471ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
1481ed65febSNathan Gauër   SmallPtrSet<BasicBlock *, 2> Output;
1491ed65febSNathan Gauër   for (BasicBlock &BB : F) {
1501ed65febSNathan Gauër     for (Instruction &I : BB) {
1511ed65febSNathan Gauër       if (getDesignatedMergeBlock(&I) != nullptr)
1521ed65febSNathan Gauër         Output.insert(&BB);
1531ed65febSNathan Gauër     }
1541ed65febSNathan Gauër   }
1551ed65febSNathan Gauër   return Output;
1561ed65febSNathan Gauër }
1571ed65febSNathan Gauër 
1581ed65febSNathan Gauër // Returns all basic blocks in |F| referenced by at least 1
1591ed65febSNathan Gauër // OpSelectionMerge/OpLoopMerge instruction.
1601ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
1611ed65febSNathan Gauër   SmallPtrSet<BasicBlock *, 2> Output;
1621ed65febSNathan Gauër   for (BasicBlock &BB : F) {
1631ed65febSNathan Gauër     for (Instruction &I : BB) {
1641ed65febSNathan Gauër       BasicBlock *MB = getDesignatedMergeBlock(&I);
1651ed65febSNathan Gauër       if (MB != nullptr)
1661ed65febSNathan Gauër         Output.insert(MB);
1671ed65febSNathan Gauër     }
1681ed65febSNathan Gauër   }
1691ed65febSNathan Gauër   return Output;
1701ed65febSNathan Gauër }
1711ed65febSNathan Gauër 
1721ed65febSNathan Gauër // Return all the merge instructions contained in BB.
1731ed65febSNathan Gauër // Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge
1741ed65febSNathan Gauër // instruction, but this can happen while we structurize the CFG.
1751ed65febSNathan Gauër std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
1761ed65febSNathan Gauër   std::vector<Instruction *> Output;
1771ed65febSNathan Gauër   for (Instruction &I : BB)
1781ed65febSNathan Gauër     if (isMergeInstruction(&I))
1791ed65febSNathan Gauër       Output.push_back(&I);
1801ed65febSNathan Gauër   return Output;
1811ed65febSNathan Gauër }
1821ed65febSNathan Gauër 
1831ed65febSNathan Gauër // Returns all basic blocks in |F| referenced as continue target by at least 1
1841ed65febSNathan Gauër // OpLoopMerge instruction.
1851ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
1861ed65febSNathan Gauër   SmallPtrSet<BasicBlock *, 2> Output;
1871ed65febSNathan Gauër   for (BasicBlock &BB : F) {
1881ed65febSNathan Gauër     for (Instruction &I : BB) {
1891ed65febSNathan Gauër       BasicBlock *MB = getDesignatedContinueBlock(&I);
1901ed65febSNathan Gauër       if (MB != nullptr)
1911ed65febSNathan Gauër         Output.insert(MB);
1921ed65febSNathan Gauër     }
1931ed65febSNathan Gauër   }
1941ed65febSNathan Gauër   return Output;
1951ed65febSNathan Gauër }
1961ed65febSNathan Gauër 
1971ed65febSNathan Gauër // Do a preorder traversal of the CFG starting from the BB |Start|.
1981ed65febSNathan Gauër // point. Calls |op| on each basic block encountered during the traversal.
1991ed65febSNathan Gauër void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
2001ed65febSNathan Gauër   std::stack<BasicBlock *> ToVisit;
2011ed65febSNathan Gauër   SmallPtrSet<BasicBlock *, 8> Seen;
2021ed65febSNathan Gauër 
2031ed65febSNathan Gauër   ToVisit.push(&Start);
2041ed65febSNathan Gauër   Seen.insert(ToVisit.top());
2051ed65febSNathan Gauër   while (ToVisit.size() != 0) {
2061ed65febSNathan Gauër     BasicBlock *BB = ToVisit.top();
2071ed65febSNathan Gauër     ToVisit.pop();
2081ed65febSNathan Gauër 
2091ed65febSNathan Gauër     if (!op(BB))
2101ed65febSNathan Gauër       continue;
2111ed65febSNathan Gauër 
2121ed65febSNathan Gauër     for (auto Succ : successors(BB)) {
2131ed65febSNathan Gauër       if (Seen.contains(Succ))
2141ed65febSNathan Gauër         continue;
2151ed65febSNathan Gauër       ToVisit.push(Succ);
2161ed65febSNathan Gauër       Seen.insert(Succ);
2171ed65febSNathan Gauër     }
2181ed65febSNathan Gauër   }
2191ed65febSNathan Gauër }
2201ed65febSNathan Gauër 
2211ed65febSNathan Gauër // Replaces the conditional and unconditional branch targets of |BB| by
2221ed65febSNathan Gauër // |NewTarget| if the target was |OldTarget|. This function also makes sure the
2231ed65febSNathan Gauër // associated merge instruction gets updated accordingly.
2241ed65febSNathan Gauër void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
2251ed65febSNathan Gauër                             BasicBlock *NewTarget) {
2261ed65febSNathan Gauër   auto *BI = cast<BranchInst>(BB->getTerminator());
2271ed65febSNathan Gauër 
2281ed65febSNathan Gauër   // 1. Replace all matching successors.
2291ed65febSNathan Gauër   for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
2301ed65febSNathan Gauër     if (BI->getSuccessor(i) == OldTarget)
2311ed65febSNathan Gauër       BI->setSuccessor(i, NewTarget);
2321ed65febSNathan Gauër   }
2331ed65febSNathan Gauër 
2341ed65febSNathan Gauër   // Branch was unconditional, no fixup required.
2351ed65febSNathan Gauër   if (BI->isUnconditional())
2361ed65febSNathan Gauër     return;
2371ed65febSNathan Gauër 
2381ed65febSNathan Gauër   // Branch had 2 successors, maybe now both are the same?
2391ed65febSNathan Gauër   if (BI->getSuccessor(0) != BI->getSuccessor(1))
2401ed65febSNathan Gauër     return;
2411ed65febSNathan Gauër 
2421ed65febSNathan Gauër   // Note: we may end up here because the original IR had such branches.
2431ed65febSNathan Gauër   // This means Target is not necessarily equal to NewTarget.
2441ed65febSNathan Gauër   IRBuilder<> Builder(BB);
2451ed65febSNathan Gauër   Builder.SetInsertPoint(BI);
2461ed65febSNathan Gauër   Builder.CreateBr(BI->getSuccessor(0));
2471ed65febSNathan Gauër   BI->eraseFromParent();
2481ed65febSNathan Gauër 
2491ed65febSNathan Gauër   // The branch was the only instruction, nothing else to do.
2501ed65febSNathan Gauër   if (BB->size() == 1)
2511ed65febSNathan Gauër     return;
2521ed65febSNathan Gauër 
2531ed65febSNathan Gauër   // Otherwise, we need to check: was there an OpSelectionMerge before this
2541ed65febSNathan Gauër   // branch? If we removed the OpBranchConditional, we must also remove the
2551ed65febSNathan Gauër   // OpSelectionMerge. This is not valid for OpLoopMerge:
2561ed65febSNathan Gauër   IntrinsicInst *II =
2571ed65febSNathan Gauër       dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode());
2581ed65febSNathan Gauër   if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge)
2591ed65febSNathan Gauër     return;
2601ed65febSNathan Gauër 
2611ed65febSNathan Gauër   Constant *C = cast<Constant>(II->getOperand(0));
2621ed65febSNathan Gauër   II->eraseFromParent();
2631ed65febSNathan Gauër   if (!C->isConstantUsed())
2641ed65febSNathan Gauër     C->destroyConstant();
2651ed65febSNathan Gauër }
2661ed65febSNathan Gauër 
2671ed65febSNathan Gauër // Replaces the target of branch instruction in |BB| with |NewTarget| if it
2681ed65febSNathan Gauër // was |OldTarget|. This function also fixes the associated merge instruction.
2691ed65febSNathan Gauër // Note: this function does not simplify branching instructions, it only updates
2701ed65febSNathan Gauër // targets. See also: simplifyBranches.
2711ed65febSNathan Gauër void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
2721ed65febSNathan Gauër                           BasicBlock *NewTarget) {
2731ed65febSNathan Gauër   auto *T = BB->getTerminator();
2741ed65febSNathan Gauër   if (isa<ReturnInst>(T))
2751ed65febSNathan Gauër     return;
2761ed65febSNathan Gauër 
2771ed65febSNathan Gauër   if (isa<BranchInst>(T))
2781ed65febSNathan Gauër     return replaceIfBranchTargets(BB, OldTarget, NewTarget);
2791ed65febSNathan Gauër 
2801ed65febSNathan Gauër   if (auto *SI = dyn_cast<SwitchInst>(T)) {
2811ed65febSNathan Gauër     for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
2821ed65febSNathan Gauër       if (SI->getSuccessor(i) == OldTarget)
2831ed65febSNathan Gauër         SI->setSuccessor(i, NewTarget);
2841ed65febSNathan Gauër     }
2851ed65febSNathan Gauër     return;
2861ed65febSNathan Gauër   }
2871ed65febSNathan Gauër 
2881ed65febSNathan Gauër   assert(false && "Unhandled terminator type.");
2891ed65febSNathan Gauër }
2901ed65febSNathan Gauër 
2911ed65febSNathan Gauër } // anonymous namespace
2921ed65febSNathan Gauër 
2931ed65febSNathan Gauër // Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
2941ed65febSNathan Gauër // adding merge instructions when required.
2951ed65febSNathan Gauër class SPIRVStructurizer : public FunctionPass {
2961ed65febSNathan Gauër 
2971ed65febSNathan Gauër   struct DivergentConstruct;
2981ed65febSNathan Gauër   // Represents a list of condition/loops/switch constructs.
2991ed65febSNathan Gauër   // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
3001ed65febSNathan Gauër   // constructs.
3011ed65febSNathan Gauër   using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
3021ed65febSNathan Gauër 
3031ed65febSNathan Gauër   // Represents a divergent construct in the SPIR-V sense.
3041ed65febSNathan Gauër   // Such constructs are represented by a header (entry), a merge block (exit),
3051ed65febSNathan Gauër   // and possibly a continue block (back-edge). A construct can contain other
3061ed65febSNathan Gauër   // constructs, but their boundaries do not cross.
3071ed65febSNathan Gauër   struct DivergentConstruct {
3081ed65febSNathan Gauër     BasicBlock *Header = nullptr;
3091ed65febSNathan Gauër     BasicBlock *Merge = nullptr;
3101ed65febSNathan Gauër     BasicBlock *Continue = nullptr;
3111ed65febSNathan Gauër 
3121ed65febSNathan Gauër     DivergentConstruct *Parent = nullptr;
3131ed65febSNathan Gauër     ConstructList Children;
3141ed65febSNathan Gauër   };
3151ed65febSNathan Gauër 
// A helper class to clean the construct boundaries.
// It is used to gather the list of blocks that should belong to each
// divergent construct, and possibly modify CFG edges when exits would cross
// the boundary of multiple constructs.
3201ed65febSNathan Gauër   struct Splitter {
3211ed65febSNathan Gauër     Function &F;
3221ed65febSNathan Gauër     LoopInfo &LI;
3231ed65febSNathan Gauër     DomTreeBuilder::BBDomTree DT;
3241ed65febSNathan Gauër     DomTreeBuilder::BBPostDomTree PDT;
3251ed65febSNathan Gauër 
3261ed65febSNathan Gauër     Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
3271ed65febSNathan Gauër 
3281ed65febSNathan Gauër     void invalidate() {
3291ed65febSNathan Gauër       PDT.recalculate(F);
3301ed65febSNathan Gauër       DT.recalculate(F);
3311ed65febSNathan Gauër     }
3321ed65febSNathan Gauër 
3331ed65febSNathan Gauër     // Returns the list of blocks that belong to a SPIR-V loop construct,
3341ed65febSNathan Gauër     // including the continue construct.
3351ed65febSNathan Gauër     std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
3361ed65febSNathan Gauër                                                      BasicBlock *Merge) {
3371ed65febSNathan Gauër       assert(DT.dominates(Header, Merge));
3381ed65febSNathan Gauër       std::vector<BasicBlock *> Output;
3391ed65febSNathan Gauër       partialOrderVisit(*Header, [&](BasicBlock *BB) {
3401ed65febSNathan Gauër         if (BB == Merge)
3411ed65febSNathan Gauër           return false;
3421ed65febSNathan Gauër         if (DT.dominates(Merge, BB) || !DT.dominates(Header, BB))
3431ed65febSNathan Gauër           return false;
3441ed65febSNathan Gauër         Output.push_back(BB);
3451ed65febSNathan Gauër         return true;
3461ed65febSNathan Gauër       });
3471ed65febSNathan Gauër       return Output;
3481ed65febSNathan Gauër     }
3491ed65febSNathan Gauër 
3501ed65febSNathan Gauër     // Returns the list of blocks that belong to a SPIR-V selection construct.
3511ed65febSNathan Gauër     std::vector<BasicBlock *>
3521ed65febSNathan Gauër     getSelectionConstructBlocks(DivergentConstruct *Node) {
3531ed65febSNathan Gauër       assert(DT.dominates(Node->Header, Node->Merge));
3541ed65febSNathan Gauër       BlockSet OutsideBlocks;
3551ed65febSNathan Gauër       OutsideBlocks.insert(Node->Merge);
3561ed65febSNathan Gauër 
3571ed65febSNathan Gauër       for (DivergentConstruct *It = Node->Parent; It != nullptr;
3581ed65febSNathan Gauër            It = It->Parent) {
3591ed65febSNathan Gauër         OutsideBlocks.insert(It->Merge);
3601ed65febSNathan Gauër         if (It->Continue)
3611ed65febSNathan Gauër           OutsideBlocks.insert(It->Continue);
3621ed65febSNathan Gauër       }
3631ed65febSNathan Gauër 
3641ed65febSNathan Gauër       std::vector<BasicBlock *> Output;
3651ed65febSNathan Gauër       partialOrderVisit(*Node->Header, [&](BasicBlock *BB) {
3661ed65febSNathan Gauër         if (OutsideBlocks.count(BB) != 0)
3671ed65febSNathan Gauër           return false;
3681ed65febSNathan Gauër         if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
3691ed65febSNathan Gauër           return false;
3701ed65febSNathan Gauër         Output.push_back(BB);
3711ed65febSNathan Gauër         return true;
3721ed65febSNathan Gauër       });
3731ed65febSNathan Gauër       return Output;
3741ed65febSNathan Gauër     }
3751ed65febSNathan Gauër 
3761ed65febSNathan Gauër     // Returns the list of blocks that belong to a SPIR-V switch construct.
3771ed65febSNathan Gauër     std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
3781ed65febSNathan Gauër                                                        BasicBlock *Merge) {
3791ed65febSNathan Gauër       assert(DT.dominates(Header, Merge));
3801ed65febSNathan Gauër 
3811ed65febSNathan Gauër       std::vector<BasicBlock *> Output;
3821ed65febSNathan Gauër       partialOrderVisit(*Header, [&](BasicBlock *BB) {
3831ed65febSNathan Gauër         // the blocks structurally dominated by a switch header,
3841ed65febSNathan Gauër         if (!DT.dominates(Header, BB))
3851ed65febSNathan Gauër           return false;
3861ed65febSNathan Gauër         // excluding blocks structurally dominated by the switch header’s merge
3871ed65febSNathan Gauër         // block.
3881ed65febSNathan Gauër         if (DT.dominates(Merge, BB) || BB == Merge)
3891ed65febSNathan Gauër           return false;
3901ed65febSNathan Gauër         Output.push_back(BB);
3911ed65febSNathan Gauër         return true;
3921ed65febSNathan Gauër       });
3931ed65febSNathan Gauër       return Output;
3941ed65febSNathan Gauër     }
3951ed65febSNathan Gauër 
3961ed65febSNathan Gauër     // Returns the list of blocks that belong to a SPIR-V case construct.
3971ed65febSNathan Gauër     std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
3981ed65febSNathan Gauër                                                      BasicBlock *Merge) {
3991ed65febSNathan Gauër       assert(DT.dominates(Target, Merge));
4001ed65febSNathan Gauër 
4011ed65febSNathan Gauër       std::vector<BasicBlock *> Output;
4021ed65febSNathan Gauër       partialOrderVisit(*Target, [&](BasicBlock *BB) {
4031ed65febSNathan Gauër         // the blocks structurally dominated by an OpSwitch Target or Default
4041ed65febSNathan Gauër         // block
4051ed65febSNathan Gauër         if (!DT.dominates(Target, BB))
4061ed65febSNathan Gauër           return false;
4071ed65febSNathan Gauër         // excluding the blocks structurally dominated by the OpSwitch
4081ed65febSNathan Gauër         // construct’s corresponding merge block.
4091ed65febSNathan Gauër         if (DT.dominates(Merge, BB) || BB == Merge)
4101ed65febSNathan Gauër           return false;
4111ed65febSNathan Gauër         Output.push_back(BB);
4121ed65febSNathan Gauër         return true;
4131ed65febSNathan Gauër       });
4141ed65febSNathan Gauër       return Output;
4151ed65febSNathan Gauër     }
4161ed65febSNathan Gauër 
4171ed65febSNathan Gauër     // Splits the given edges by recreating proxy nodes so that the destination
418cba70550SNathan Gauër     // has unique incoming edges from this region.
4191ed65febSNathan Gauër     //
4201ed65febSNathan Gauër     // clang-format off
4211ed65febSNathan Gauër     //
4221ed65febSNathan Gauër     // In SPIR-V, constructs must have a single exit/merge.
4231ed65febSNathan Gauër     // Given nodes A and B in the construct, a node C outside, and the following edges.
4241ed65febSNathan Gauër     //  A -> C
4251ed65febSNathan Gauër     //  B -> C
4261ed65febSNathan Gauër     //
4271ed65febSNathan Gauër     // In such cases, we must create a new exit node D, that belong to the construct to make is viable:
4281ed65febSNathan Gauër     // A -> D -> C
4291ed65febSNathan Gauër     // B -> D -> C
4301ed65febSNathan Gauër     //
431cba70550SNathan Gauër     // This is fine (assuming C has no PHI nodes), but requires handling the merge instruction here.
432cba70550SNathan Gauër     // By adding a proxy node, we create a regular divergent shape which can easily be regularized later on.
4331ed65febSNathan Gauër     // A -> D -> D1 -> C
4341ed65febSNathan Gauër     // B -> D -> D2 -> C
4351ed65febSNathan Gauër     //
436cba70550SNathan Gauër     // A, B, D belongs to the construct. D is the exit. D1 and D2 are empty.
4371ed65febSNathan Gauër     //
4381ed65febSNathan Gauër     // clang-format on
4391ed65febSNathan Gauër     std::vector<Edge>
4401ed65febSNathan Gauër     createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
441cba70550SNathan Gauër       std::unordered_set<BasicBlock *> Seen;
4421ed65febSNathan Gauër       std::vector<Edge> Output;
4431ed65febSNathan Gauër       Output.reserve(Edges.size());
4441ed65febSNathan Gauër 
4451ed65febSNathan Gauër       for (auto &[Src, Dst] : Edges) {
446cba70550SNathan Gauër         auto [Iterator, Inserted] = Seen.insert(Src);
447cba70550SNathan Gauër         if (!Inserted) {
448cba70550SNathan Gauër           // Src already a source node. Cannot have 2 edges from A to B.
449cba70550SNathan Gauër           // Creating alias source block.
450cba70550SNathan Gauër           BasicBlock *NewSrc = BasicBlock::Create(
451cba70550SNathan Gauër               F.getContext(), Src->getName() + ".new.src", &F);
4521ed65febSNathan Gauër           replaceBranchTargets(Src, Dst, NewSrc);
4531ed65febSNathan Gauër           IRBuilder<> Builder(NewSrc);
4541ed65febSNathan Gauër           Builder.CreateBr(Dst);
455cba70550SNathan Gauër           Src = NewSrc;
456cba70550SNathan Gauër         }
4571ed65febSNathan Gauër 
458cba70550SNathan Gauër         Output.emplace_back(Src, Dst);
4591ed65febSNathan Gauër       }
4601ed65febSNathan Gauër 
4611ed65febSNathan Gauër       return Output;
4621ed65febSNathan Gauër     }
4631ed65febSNathan Gauër 
464cba70550SNathan Gauër     AllocaInst *CreateVariable(Function &F, Type *Type,
465cba70550SNathan Gauër                                BasicBlock::iterator Position) {
466cba70550SNathan Gauër       const DataLayout &DL = F.getDataLayout();
467cba70550SNathan Gauër       return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
468cba70550SNathan Gauër                             Position);
469cba70550SNathan Gauër     }
470cba70550SNathan Gauër 
4711ed65febSNathan Gauër     // Given a construct defined by |Header|, and a list of exiting edges
4721ed65febSNathan Gauër     // |Edges|, creates a new single exit node, fixing up those edges.
4731ed65febSNathan Gauër     BasicBlock *createSingleExitNode(BasicBlock *Header,
4741ed65febSNathan Gauër                                      std::vector<Edge> &Edges) {
475cba70550SNathan Gauër 
476cba70550SNathan Gauër       std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
4771ed65febSNathan Gauër 
4781ed65febSNathan Gauër       std::vector<BasicBlock *> Dsts;
4791ed65febSNathan Gauër       std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
480cba70550SNathan Gauër       auto NewExit = BasicBlock::Create(F.getContext(),
481cba70550SNathan Gauër                                         Header->getName() + ".new.exit", &F);
482cba70550SNathan Gauër       IRBuilder<> ExitBuilder(NewExit);
4831ed65febSNathan Gauër       for (auto &[Src, Dst] : FixedEdges) {
4841ed65febSNathan Gauër         if (DstToIndex.count(Dst) != 0)
4851ed65febSNathan Gauër           continue;
4861ed65febSNathan Gauër         DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
4871ed65febSNathan Gauër         Dsts.push_back(Dst);
4881ed65febSNathan Gauër       }
4891ed65febSNathan Gauër 
4901ed65febSNathan Gauër       if (Dsts.size() == 1) {
4911ed65febSNathan Gauër         for (auto &[Src, Dst] : FixedEdges) {
4921ed65febSNathan Gauër           replaceBranchTargets(Src, Dst, NewExit);
4931ed65febSNathan Gauër         }
4941ed65febSNathan Gauër         ExitBuilder.CreateBr(Dsts[0]);
4951ed65febSNathan Gauër         return NewExit;
4961ed65febSNathan Gauër       }
4971ed65febSNathan Gauër 
498cba70550SNathan Gauër       AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
499cba70550SNathan Gauër                                             F.begin()->getFirstInsertionPt());
5001ed65febSNathan Gauër       for (auto &[Src, Dst] : FixedEdges) {
501cba70550SNathan Gauër         IRBuilder<> B2(Src);
502cba70550SNathan Gauër         B2.SetInsertPoint(Src->getFirstInsertionPt());
503cba70550SNathan Gauër         B2.CreateStore(DstToIndex[Dst], Variable);
5041ed65febSNathan Gauër         replaceBranchTargets(Src, Dst, NewExit);
5051ed65febSNathan Gauër       }
5061ed65febSNathan Gauër 
507cba70550SNathan Gauër       llvm::Value *Load =
508cba70550SNathan Gauër           ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
509cba70550SNathan Gauër 
5101ed65febSNathan Gauër       // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
5111ed65febSNathan Gauër       // OpBranch are allowed to exist without a new OpSelectionMerge if one of
5121ed65febSNathan Gauër       // the branch is the parent's merge node, while OpSwitches are not.
5131ed65febSNathan Gauër       if (Dsts.size() == 2) {
514cba70550SNathan Gauër         Value *Condition =
515cba70550SNathan Gauër             ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load);
5161ed65febSNathan Gauër         ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
5171ed65febSNathan Gauër         return NewExit;
5181ed65febSNathan Gauër       }
5191ed65febSNathan Gauër 
520cba70550SNathan Gauër       SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
5211ed65febSNathan Gauër       for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
5221ed65febSNathan Gauër         Sw->addCase(DstToIndex[*It], *It);
5231ed65febSNathan Gauër       }
5241ed65febSNathan Gauër       return NewExit;
5251ed65febSNathan Gauër     }
5261ed65febSNathan Gauër   };
5271ed65febSNathan Gauër 
5281ed65febSNathan Gauër   /// Create a value in BB set to the value associated with the branch the block
5291ed65febSNathan Gauër   /// terminator will take.
5301ed65febSNathan Gauër   Value *createExitVariable(
5311ed65febSNathan Gauër       BasicBlock *BB,
5321ed65febSNathan Gauër       const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
5331ed65febSNathan Gauër     auto *T = BB->getTerminator();
5341ed65febSNathan Gauër     if (isa<ReturnInst>(T))
5351ed65febSNathan Gauër       return nullptr;
5361ed65febSNathan Gauër 
5371ed65febSNathan Gauër     IRBuilder<> Builder(BB);
5381ed65febSNathan Gauër     Builder.SetInsertPoint(T);
5391ed65febSNathan Gauër 
5401ed65febSNathan Gauër     if (auto *BI = dyn_cast<BranchInst>(T)) {
5411ed65febSNathan Gauër 
5421ed65febSNathan Gauër       BasicBlock *LHSTarget = BI->getSuccessor(0);
5431ed65febSNathan Gauër       BasicBlock *RHSTarget =
5441ed65febSNathan Gauër           BI->isConditional() ? BI->getSuccessor(1) : nullptr;
5451ed65febSNathan Gauër 
5461ed65febSNathan Gauër       Value *LHS = TargetToValue.count(LHSTarget) != 0
5471ed65febSNathan Gauër                        ? TargetToValue.at(LHSTarget)
5481ed65febSNathan Gauër                        : nullptr;
5491ed65febSNathan Gauër       Value *RHS = TargetToValue.count(RHSTarget) != 0
5501ed65febSNathan Gauër                        ? TargetToValue.at(RHSTarget)
5511ed65febSNathan Gauër                        : nullptr;
5521ed65febSNathan Gauër 
5531ed65febSNathan Gauër       if (LHS == nullptr || RHS == nullptr)
5541ed65febSNathan Gauër         return LHS == nullptr ? RHS : LHS;
5551ed65febSNathan Gauër       return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
5561ed65febSNathan Gauër     }
5571ed65febSNathan Gauër 
5581ed65febSNathan Gauër     // TODO: add support for switch cases.
5591ed65febSNathan Gauër     llvm_unreachable("Unhandled terminator type.");
5601ed65febSNathan Gauër   }
5611ed65febSNathan Gauër 
5621ed65febSNathan Gauër   // Creates a new basic block in F with a single OpUnreachable instruction.
5631ed65febSNathan Gauër   BasicBlock *CreateUnreachable(Function &F) {
564cba70550SNathan Gauër     BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F);
5651ed65febSNathan Gauër     IRBuilder<> Builder(BB);
5661ed65febSNathan Gauër     Builder.CreateUnreachable();
5671ed65febSNathan Gauër     return BB;
5681ed65febSNathan Gauër   }
5691ed65febSNathan Gauër 
5701ed65febSNathan Gauër   // Add OpLoopMerge instruction on cycles.
5711ed65febSNathan Gauër   bool addMergeForLoops(Function &F) {
5721ed65febSNathan Gauër     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
5731ed65febSNathan Gauër     auto *TopLevelRegion =
5741ed65febSNathan Gauër         getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
5751ed65febSNathan Gauër             .getRegionInfo()
5761ed65febSNathan Gauër             .getTopLevelRegion();
5771ed65febSNathan Gauër 
5781ed65febSNathan Gauër     bool Modified = false;
5791ed65febSNathan Gauër     for (auto &BB : F) {
5801ed65febSNathan Gauër       // Not a loop header. Ignoring for now.
5811ed65febSNathan Gauër       if (!LI.isLoopHeader(&BB))
5821ed65febSNathan Gauër         continue;
5831ed65febSNathan Gauër       auto *L = LI.getLoopFor(&BB);
5841ed65febSNathan Gauër 
5851ed65febSNathan Gauër       // This loop header is not the entrance of a convergence region. Ignoring
5861ed65febSNathan Gauër       // this block.
5871ed65febSNathan Gauër       auto *CR = getRegionForHeader(TopLevelRegion, &BB);
5881ed65febSNathan Gauër       if (CR == nullptr)
5891ed65febSNathan Gauër         continue;
5901ed65febSNathan Gauër 
5911ed65febSNathan Gauër       IRBuilder<> Builder(&BB);
5921ed65febSNathan Gauër 
5931ed65febSNathan Gauër       auto *Merge = getExitFor(CR);
5941ed65febSNathan Gauër       // We are indeed in a loop, but there are no exits (infinite loop).
5951ed65febSNathan Gauër       // This could be caused by a bad shader, but also could be an artifact
5961ed65febSNathan Gauër       // from an earlier optimization. It is not always clear if structurally
5971ed65febSNathan Gauër       // reachable means runtime reachable, so we cannot error-out. What we must
5981ed65febSNathan Gauër       // do however is to make is legal on the SPIR-V point of view, hence
5991ed65febSNathan Gauër       // adding an unreachable merge block.
6001ed65febSNathan Gauër       if (Merge == nullptr) {
6011ed65febSNathan Gauër         BranchInst *Br = cast<BranchInst>(BB.getTerminator());
6021ed65febSNathan Gauër         assert(Br &&
6031ed65febSNathan Gauër                "This assumes the branch is not a switch. Maybe that's wrong?");
6041ed65febSNathan Gauër         assert(cast<BranchInst>(BB.getTerminator())->isUnconditional());
6051ed65febSNathan Gauër 
6061ed65febSNathan Gauër         Merge = CreateUnreachable(F);
6071ed65febSNathan Gauër         Builder.SetInsertPoint(Br);
6081ed65febSNathan Gauër         Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0));
6091ed65febSNathan Gauër         Br->eraseFromParent();
6101ed65febSNathan Gauër       }
6111ed65febSNathan Gauër 
6121ed65febSNathan Gauër       auto *Continue = L->getLoopLatch();
6131ed65febSNathan Gauër 
6141ed65febSNathan Gauër       Builder.SetInsertPoint(BB.getTerminator());
6151ed65febSNathan Gauër       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
6161ed65febSNathan Gauër       auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
6171ed65febSNathan Gauër       SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
6181ed65febSNathan Gauër 
6191ed65febSNathan Gauër       Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {}, {Args});
6201ed65febSNathan Gauër       Modified = true;
6211ed65febSNathan Gauër     }
6221ed65febSNathan Gauër 
6231ed65febSNathan Gauër     return Modified;
6241ed65febSNathan Gauër   }
6251ed65febSNathan Gauër 
6261ed65febSNathan Gauër   // Adds an OpSelectionMerge to the immediate dominator or each node with an
6271ed65febSNathan Gauër   // in-degree of 2 or more which is not already the merge target of an
6281ed65febSNathan Gauër   // OpLoopMerge/OpSelectionMerge.
6291ed65febSNathan Gauër   bool addMergeForNodesWithMultiplePredecessors(Function &F) {
6301ed65febSNathan Gauër     DomTreeBuilder::BBDomTree DT;
6311ed65febSNathan Gauër     DT.recalculate(F);
6321ed65febSNathan Gauër 
6331ed65febSNathan Gauër     bool Modified = false;
6341ed65febSNathan Gauër     for (auto &BB : F) {
6351ed65febSNathan Gauër       if (pred_size(&BB) <= 1)
6361ed65febSNathan Gauër         continue;
6371ed65febSNathan Gauër 
6381ed65febSNathan Gauër       if (hasLoopMergeInstruction(BB) && pred_size(&BB) <= 2)
6391ed65febSNathan Gauër         continue;
6401ed65febSNathan Gauër 
6411ed65febSNathan Gauër       assert(DT.getNode(&BB)->getIDom());
6421ed65febSNathan Gauër       BasicBlock *Header = DT.getNode(&BB)->getIDom()->getBlock();
6431ed65febSNathan Gauër 
6441ed65febSNathan Gauër       if (isDefinedAsSelectionMergeBy(*Header, BB))
6451ed65febSNathan Gauër         continue;
6461ed65febSNathan Gauër 
6471ed65febSNathan Gauër       IRBuilder<> Builder(Header);
6481ed65febSNathan Gauër       Builder.SetInsertPoint(Header->getTerminator());
6491ed65febSNathan Gauër 
6501ed65febSNathan Gauër       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
651380bb51bSjoaosaffran       createOpSelectMerge(&Builder, MergeAddress);
6521ed65febSNathan Gauër 
6531ed65febSNathan Gauër       Modified = true;
6541ed65febSNathan Gauër     }
6551ed65febSNathan Gauër 
6561ed65febSNathan Gauër     return Modified;
6571ed65febSNathan Gauër   }
6581ed65febSNathan Gauër 
6591ed65febSNathan Gauër   // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts
6601ed65febSNathan Gauër   // them to put the "largest" first. A merge instruction is defined as larger
6611ed65febSNathan Gauër   // than another when its target merge block post-dominates the other target's
6621ed65febSNathan Gauër   // merge block. (This ordering should match the nesting ordering of the source
6631ed65febSNathan Gauër   // HLSL).
6641ed65febSNathan Gauër   bool sortSelectionMerge(Function &F, BasicBlock &Block) {
6651ed65febSNathan Gauër     std::vector<Instruction *> MergeInstructions;
6661ed65febSNathan Gauër     for (Instruction &I : Block)
6671ed65febSNathan Gauër       if (isMergeInstruction(&I))
6681ed65febSNathan Gauër         MergeInstructions.push_back(&I);
6691ed65febSNathan Gauër 
6701ed65febSNathan Gauër     if (MergeInstructions.size() <= 1)
6711ed65febSNathan Gauër       return false;
6721ed65febSNathan Gauër 
6731ed65febSNathan Gauër     Instruction *InsertionPoint = *MergeInstructions.begin();
6741ed65febSNathan Gauër 
6751ed65febSNathan Gauër     PartialOrderingVisitor Visitor(F);
6761ed65febSNathan Gauër     std::sort(MergeInstructions.begin(), MergeInstructions.end(),
6771ed65febSNathan Gauër               [&Visitor](Instruction *Left, Instruction *Right) {
6781ed65febSNathan Gauër                 if (Left == Right)
6791ed65febSNathan Gauër                   return false;
6801ed65febSNathan Gauër                 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
6811ed65febSNathan Gauër                 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
6821ed65febSNathan Gauër                 return !Visitor.compare(RightMerge, LeftMerge);
6831ed65febSNathan Gauër               });
6841ed65febSNathan Gauër 
6851ed65febSNathan Gauër     for (Instruction *I : MergeInstructions) {
686*304a9909SJeremy Morse       I->moveBefore(InsertionPoint->getIterator());
6871ed65febSNathan Gauër       InsertionPoint = I;
6881ed65febSNathan Gauër     }
6891ed65febSNathan Gauër 
6901ed65febSNathan Gauër     return true;
6911ed65febSNathan Gauër   }
6921ed65febSNathan Gauër 
6931ed65febSNathan Gauër   // Sorts selection merge headers in |F|.
6941ed65febSNathan Gauër   // A is sorted before B if the merge block designated by B is an ancestor of
6951ed65febSNathan Gauër   // the one designated by A.
6961ed65febSNathan Gauër   bool sortSelectionMergeHeaders(Function &F) {
6971ed65febSNathan Gauër     bool Modified = false;
6981ed65febSNathan Gauër     for (BasicBlock &BB : F) {
6991ed65febSNathan Gauër       Modified |= sortSelectionMerge(F, BB);
7001ed65febSNathan Gauër     }
7011ed65febSNathan Gauër     return Modified;
7021ed65febSNathan Gauër   }
7031ed65febSNathan Gauër 
7041ed65febSNathan Gauër   // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge
7051ed65febSNathan Gauër   // instructions so each basic block contains only a single merge instruction.
7061ed65febSNathan Gauër   bool splitBlocksWithMultipleHeaders(Function &F) {
7071ed65febSNathan Gauër     std::stack<BasicBlock *> Work;
7081ed65febSNathan Gauër     for (auto &BB : F) {
7091ed65febSNathan Gauër       std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
7101ed65febSNathan Gauër       if (MergeInstructions.size() <= 1)
7111ed65febSNathan Gauër         continue;
7121ed65febSNathan Gauër       Work.push(&BB);
7131ed65febSNathan Gauër     }
7141ed65febSNathan Gauër 
7151ed65febSNathan Gauër     const bool Modified = Work.size() > 0;
7161ed65febSNathan Gauër     while (Work.size() > 0) {
7171ed65febSNathan Gauër       BasicBlock *Header = Work.top();
7181ed65febSNathan Gauër       Work.pop();
7191ed65febSNathan Gauër 
7201ed65febSNathan Gauër       std::vector<Instruction *> MergeInstructions =
7211ed65febSNathan Gauër           getMergeInstructions(*Header);
7221ed65febSNathan Gauër       for (unsigned i = 1; i < MergeInstructions.size(); i++) {
7231ed65febSNathan Gauër         BasicBlock *NewBlock =
7241ed65febSNathan Gauër             Header->splitBasicBlock(MergeInstructions[i], "new.header");
7251ed65febSNathan Gauër 
7261ed65febSNathan Gauër         if (getDesignatedContinueBlock(MergeInstructions[0]) == nullptr) {
7271ed65febSNathan Gauër           BasicBlock *Unreachable = CreateUnreachable(F);
7281ed65febSNathan Gauër 
7291ed65febSNathan Gauër           BranchInst *BI = cast<BranchInst>(Header->getTerminator());
7301ed65febSNathan Gauër           IRBuilder<> Builder(Header);
7311ed65febSNathan Gauër           Builder.SetInsertPoint(BI);
7321ed65febSNathan Gauër           Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
7331ed65febSNathan Gauër           BI->eraseFromParent();
7341ed65febSNathan Gauër         }
7351ed65febSNathan Gauër 
7361ed65febSNathan Gauër         Header = NewBlock;
7371ed65febSNathan Gauër       }
7381ed65febSNathan Gauër     }
7391ed65febSNathan Gauër 
7401ed65febSNathan Gauër     return Modified;
7411ed65febSNathan Gauër   }
7421ed65febSNathan Gauër 
7431ed65febSNathan Gauër   // Adds an OpSelectionMerge to each block with an out-degree >= 2 which
7441ed65febSNathan Gauër   // doesn't already have an OpSelectionMerge.
7451ed65febSNathan Gauër   bool addMergeForDivergentBlocks(Function &F) {
7461ed65febSNathan Gauër     DomTreeBuilder::BBPostDomTree PDT;
7471ed65febSNathan Gauër     PDT.recalculate(F);
7481ed65febSNathan Gauër     bool Modified = false;
7491ed65febSNathan Gauër 
7501ed65febSNathan Gauër     auto MergeBlocks = getMergeBlocks(F);
7511ed65febSNathan Gauër     auto ContinueBlocks = getContinueBlocks(F);
7521ed65febSNathan Gauër 
7531ed65febSNathan Gauër     for (auto &BB : F) {
7541ed65febSNathan Gauër       if (getMergeInstructions(BB).size() != 0)
7551ed65febSNathan Gauër         continue;
7561ed65febSNathan Gauër 
7571ed65febSNathan Gauër       std::vector<BasicBlock *> Candidates;
7581ed65febSNathan Gauër       for (BasicBlock *Successor : successors(&BB)) {
7591ed65febSNathan Gauër         if (MergeBlocks.contains(Successor))
7601ed65febSNathan Gauër           continue;
7611ed65febSNathan Gauër         if (ContinueBlocks.contains(Successor))
7621ed65febSNathan Gauër           continue;
7631ed65febSNathan Gauër         Candidates.push_back(Successor);
7641ed65febSNathan Gauër       }
7651ed65febSNathan Gauër 
7661ed65febSNathan Gauër       if (Candidates.size() <= 1)
7671ed65febSNathan Gauër         continue;
7681ed65febSNathan Gauër 
7691ed65febSNathan Gauër       Modified = true;
7701ed65febSNathan Gauër       BasicBlock *Merge = Candidates[0];
7711ed65febSNathan Gauër 
7721ed65febSNathan Gauër       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
7731ed65febSNathan Gauër       IRBuilder<> Builder(&BB);
7741ed65febSNathan Gauër       Builder.SetInsertPoint(BB.getTerminator());
775380bb51bSjoaosaffran       createOpSelectMerge(&Builder, MergeAddress);
7761ed65febSNathan Gauër     }
7771ed65febSNathan Gauër 
7781ed65febSNathan Gauër     return Modified;
7791ed65febSNathan Gauër   }
7801ed65febSNathan Gauër 
7811ed65febSNathan Gauër   // Gather all the exit nodes for the construct header by |Header| and
7821ed65febSNathan Gauër   // containing the blocks |Construct|.
7831ed65febSNathan Gauër   std::vector<Edge> getExitsFrom(const BlockSet &Construct,
7841ed65febSNathan Gauër                                  BasicBlock &Header) {
7851ed65febSNathan Gauër     std::vector<Edge> Output;
7861ed65febSNathan Gauër     visit(Header, [&](BasicBlock *Item) {
7871ed65febSNathan Gauër       if (Construct.count(Item) == 0)
7881ed65febSNathan Gauër         return false;
7891ed65febSNathan Gauër 
7901ed65febSNathan Gauër       for (BasicBlock *Successor : successors(Item)) {
7911ed65febSNathan Gauër         if (Construct.count(Successor) == 0)
7921ed65febSNathan Gauër           Output.emplace_back(Item, Successor);
7931ed65febSNathan Gauër       }
7941ed65febSNathan Gauër       return true;
7951ed65febSNathan Gauër     });
7961ed65febSNathan Gauër 
7971ed65febSNathan Gauër     return Output;
7981ed65febSNathan Gauër   }
7991ed65febSNathan Gauër 
8001ed65febSNathan Gauër   // Build a divergent construct tree searching from |BB|.
8011ed65febSNathan Gauër   // If |Parent| is not null, this tree is attached to the parent's tree.
8021ed65febSNathan Gauër   void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
8031ed65febSNathan Gauër                                    BasicBlock *BB, DivergentConstruct *Parent) {
8041ed65febSNathan Gauër     if (Visited.count(BB) != 0)
8051ed65febSNathan Gauër       return;
8061ed65febSNathan Gauër     Visited.insert(BB);
8071ed65febSNathan Gauër 
8081ed65febSNathan Gauër     auto MIS = getMergeInstructions(*BB);
8091ed65febSNathan Gauër     if (MIS.size() == 0) {
8101ed65febSNathan Gauër       for (BasicBlock *Successor : successors(BB))
8111ed65febSNathan Gauër         constructDivergentConstruct(Visited, S, Successor, Parent);
8121ed65febSNathan Gauër       return;
8131ed65febSNathan Gauër     }
8141ed65febSNathan Gauër 
8151ed65febSNathan Gauër     assert(MIS.size() == 1);
8161ed65febSNathan Gauër     Instruction *MI = MIS[0];
8171ed65febSNathan Gauër 
8181ed65febSNathan Gauër     BasicBlock *Merge = getDesignatedMergeBlock(MI);
8191ed65febSNathan Gauër     BasicBlock *Continue = getDesignatedContinueBlock(MI);
8201ed65febSNathan Gauër 
8211ed65febSNathan Gauër     auto Output = std::make_unique<DivergentConstruct>();
8221ed65febSNathan Gauër     Output->Header = BB;
8231ed65febSNathan Gauër     Output->Merge = Merge;
8241ed65febSNathan Gauër     Output->Continue = Continue;
8251ed65febSNathan Gauër     Output->Parent = Parent;
8261ed65febSNathan Gauër 
8271ed65febSNathan Gauër     constructDivergentConstruct(Visited, S, Merge, Parent);
8281ed65febSNathan Gauër     if (Continue)
8291ed65febSNathan Gauër       constructDivergentConstruct(Visited, S, Continue, Output.get());
8301ed65febSNathan Gauër 
8311ed65febSNathan Gauër     for (BasicBlock *Successor : successors(BB))
8321ed65febSNathan Gauër       constructDivergentConstruct(Visited, S, Successor, Output.get());
8331ed65febSNathan Gauër 
8341ed65febSNathan Gauër     if (Parent)
8351ed65febSNathan Gauër       Parent->Children.emplace_back(std::move(Output));
8361ed65febSNathan Gauër   }
8371ed65febSNathan Gauër 
8381ed65febSNathan Gauër   // Returns the blocks belonging to the divergent construct |Node|.
8391ed65febSNathan Gauër   BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
8401ed65febSNathan Gauër     assert(Node->Header && Node->Merge);
8411ed65febSNathan Gauër 
8421ed65febSNathan Gauër     if (Node->Continue) {
8431ed65febSNathan Gauër       auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge);
8441ed65febSNathan Gauër       return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
8451ed65febSNathan Gauër     }
8461ed65febSNathan Gauër 
8471ed65febSNathan Gauër     auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
8481ed65febSNathan Gauër     return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
8491ed65febSNathan Gauër   }
8501ed65febSNathan Gauër 
  // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V
  // spec.
  //
  // Children are fixed first (innermost constructs before outer ones). For a
  // construct nested in another real construct, if any exit edge is not a
  // legal break/continue, or if the merge block is shared with the parent's
  // merge or continue block, all exit edges are funneled through a new single
  // exit node which becomes the construct's merge block.
  //
  // Returns true if this construct or any of its children was modified.
  bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
    bool Modified = false;
    for (auto &Child : Node->Children)
      Modified |= fixupConstruct(S, Child.get());

    // This construct is the root construct. Does not represent any real
    // construct, just a way to access the first level of the forest.
    if (Node->Parent == nullptr)
      return Modified;

    // This node's parent is the root. Meaning this is a top-level construct.
    // There can be multiple exits, but all are guaranteed to exit at most 1
    // construct since we are at first level.
    if (Node->Parent->Header == nullptr)
      return Modified;

    // Health check for the structure.
    assert(Node->Header && Node->Merge);
    assert(Node->Parent->Header && Node->Parent->Merge);

    BlockSet ConstructBlocks = getConstructBlocks(S, Node);
    auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);

    // No edges exiting the construct.
    if (Edges.size() < 1)
      return Modified;

    // Sharing a merge block with the parent's merge or continue block is
    // enough to require a fixup, whatever the individual edges look like.
    bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
                      Node->Merge == Node->Parent->Continue;
    for (auto &[Src, Dst] : Edges) {
      // Legal exits, quoting the SPIR-V spec:
      // - Breaking from a selection construct: S is a selection construct, S
      //   is the innermost structured control-flow construct containing A,
      //   and B is the merge block for S.
      // - Breaking from the innermost loop: S is the innermost loop construct
      //   containing A, and B is the merge block for S.
      if (Node->Merge == Dst)
        continue;

      // - Entering the innermost loop's continue construct: S is the
      //   innermost loop construct containing A, and B is the continue
      //   target for S.
      if (Node->Continue == Dst)
        continue;

      // TODO: what about cases branching to another case in the switch? Seems
      // to work, but need to double check.
      HasBadEdge = true;
    }

    if (!HasBadEdge)
      return Modified;

    // Create a single exit node gathering all exit edges.
    BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges);

    // Fixup this construct's merge node to point to the new exit.
    // Note: this algorithm fixes inner-most divergence construct first. So
    // recursive structures sharing a single merge node are fixed from the
    // inside toward the outside.
    auto MergeInstructions = getMergeInstructions(*Node->Header);
    assert(MergeInstructions.size() == 1);
    Instruction *I = MergeInstructions[0];
    // Operand 0 of the merge instruction is the address of the merge block.
    BlockAddress *BA = cast<BlockAddress>(I->getOperand(0));
    if (BA->getBasicBlock() == Node->Merge) {
      auto MergeAddress = BlockAddress::get(NewExit->getParent(), NewExit);
      I->setOperand(0, MergeAddress);
    }

    // Clean up of the possible dangling BlockAddress operands to prevent MIR
    // comments about "address of removed block taken".
    if (!BA->isConstantUsed())
      BA->destroyConstant();

    Node->Merge = NewExit;
    // Regenerate the dom trees.
    S.invalidate();
    return true;
  }
9321ed65febSNathan Gauër 
9331ed65febSNathan Gauër   bool splitCriticalEdges(Function &F) {
9341ed65febSNathan Gauër     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
9351ed65febSNathan Gauër     Splitter S(F, LI);
9361ed65febSNathan Gauër 
9371ed65febSNathan Gauër     DivergentConstruct Root;
9381ed65febSNathan Gauër     BlockSet Visited;
9391ed65febSNathan Gauër     constructDivergentConstruct(Visited, S, &*F.begin(), &Root);
9401ed65febSNathan Gauër     return fixupConstruct(S, &Root);
9411ed65febSNathan Gauër   }
9421ed65febSNathan Gauër 
9431ed65febSNathan Gauër   // Simplify branches when possible:
9441ed65febSNathan Gauër   //  - if the 2 sides of a conditional branch are the same, transforms it to an
9451ed65febSNathan Gauër   //  unconditional branch.
9461ed65febSNathan Gauër   //  - if a switch has only 2 distinct successors, converts it to a conditional
9471ed65febSNathan Gauër   //  branch.
9481ed65febSNathan Gauër   bool simplifyBranches(Function &F) {
9491ed65febSNathan Gauër     bool Modified = false;
9501ed65febSNathan Gauër 
9511ed65febSNathan Gauër     for (BasicBlock &BB : F) {
9521ed65febSNathan Gauër       SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator());
9531ed65febSNathan Gauër       if (!SI)
9541ed65febSNathan Gauër         continue;
9551ed65febSNathan Gauër       if (SI->getNumCases() > 1)
9561ed65febSNathan Gauër         continue;
9571ed65febSNathan Gauër 
9581ed65febSNathan Gauër       Modified = true;
9591ed65febSNathan Gauër       IRBuilder<> Builder(&BB);
9601ed65febSNathan Gauër       Builder.SetInsertPoint(SI);
9611ed65febSNathan Gauër 
9621ed65febSNathan Gauër       if (SI->getNumCases() == 0) {
9631ed65febSNathan Gauër         Builder.CreateBr(SI->getDefaultDest());
9641ed65febSNathan Gauër       } else {
9651ed65febSNathan Gauër         Value *Condition =
9661ed65febSNathan Gauër             Builder.CreateCmp(CmpInst::ICMP_EQ, SI->getCondition(),
9671ed65febSNathan Gauër                               SI->case_begin()->getCaseValue());
9681ed65febSNathan Gauër         Builder.CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(),
9691ed65febSNathan Gauër                              SI->getDefaultDest());
9701ed65febSNathan Gauër       }
9711ed65febSNathan Gauër       SI->eraseFromParent();
9721ed65febSNathan Gauër     }
9731ed65febSNathan Gauër 
9741ed65febSNathan Gauër     return Modified;
9751ed65febSNathan Gauër   }
9761ed65febSNathan Gauër 
9771ed65febSNathan Gauër   // Makes sure every case target in |F| is unique. If 2 cases branch to the
9781ed65febSNathan Gauër   // same basic block, one of the targets is updated so it jumps to a new basic
9791ed65febSNathan Gauër   // block ending with a single unconditional branch to the original target.
9801ed65febSNathan Gauër   bool splitSwitchCases(Function &F) {
9811ed65febSNathan Gauër     bool Modified = false;
9821ed65febSNathan Gauër 
9831ed65febSNathan Gauër     for (BasicBlock &BB : F) {
9841ed65febSNathan Gauër       SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator());
9851ed65febSNathan Gauër       if (!SI)
9861ed65febSNathan Gauër         continue;
9871ed65febSNathan Gauër 
9881ed65febSNathan Gauër       BlockSet Seen;
9891ed65febSNathan Gauër       Seen.insert(SI->getDefaultDest());
9901ed65febSNathan Gauër 
9911ed65febSNathan Gauër       auto It = SI->case_begin();
9921ed65febSNathan Gauër       while (It != SI->case_end()) {
9931ed65febSNathan Gauër         BasicBlock *Target = It->getCaseSuccessor();
9941ed65febSNathan Gauër         if (Seen.count(Target) == 0) {
9951ed65febSNathan Gauër           Seen.insert(Target);
9961ed65febSNathan Gauër           ++It;
9971ed65febSNathan Gauër           continue;
9981ed65febSNathan Gauër         }
9991ed65febSNathan Gauër 
10001ed65febSNathan Gauër         Modified = true;
10011ed65febSNathan Gauër         BasicBlock *NewTarget =
10021ed65febSNathan Gauër             BasicBlock::Create(F.getContext(), "new.sw.case", &F);
10031ed65febSNathan Gauër         IRBuilder<> Builder(NewTarget);
10041ed65febSNathan Gauër         Builder.CreateBr(Target);
10051ed65febSNathan Gauër         SI->addCase(It->getCaseValue(), NewTarget);
10061ed65febSNathan Gauër         It = SI->removeCase(It);
10071ed65febSNathan Gauër       }
10081ed65febSNathan Gauër     }
10091ed65febSNathan Gauër 
10101ed65febSNathan Gauër     return Modified;
10111ed65febSNathan Gauër   }
10121ed65febSNathan Gauër 
1013cba70550SNathan Gauër   // Removes blocks not contributing to any structured CFG. This assumes there
1014cba70550SNathan Gauër   // is no PHI nodes.
10151ed65febSNathan Gauër   bool removeUselessBlocks(Function &F) {
10161ed65febSNathan Gauër     std::vector<BasicBlock *> ToRemove;
10171ed65febSNathan Gauër 
10181ed65febSNathan Gauër     auto MergeBlocks = getMergeBlocks(F);
10191ed65febSNathan Gauër     auto ContinueBlocks = getContinueBlocks(F);
10201ed65febSNathan Gauër 
10211ed65febSNathan Gauër     for (BasicBlock &BB : F) {
10221ed65febSNathan Gauër       if (BB.size() != 1)
10231ed65febSNathan Gauër         continue;
10241ed65febSNathan Gauër 
10251ed65febSNathan Gauër       if (isa<ReturnInst>(BB.getTerminator()))
10261ed65febSNathan Gauër         continue;
10271ed65febSNathan Gauër 
10281ed65febSNathan Gauër       if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
10291ed65febSNathan Gauër         continue;
10301ed65febSNathan Gauër 
10311ed65febSNathan Gauër       if (BB.getUniqueSuccessor() == nullptr)
10321ed65febSNathan Gauër         continue;
10331ed65febSNathan Gauër 
10341ed65febSNathan Gauër       BasicBlock *Successor = BB.getUniqueSuccessor();
10351ed65febSNathan Gauër       std::vector<BasicBlock *> Predecessors(predecessors(&BB).begin(),
10361ed65febSNathan Gauër                                              predecessors(&BB).end());
10371ed65febSNathan Gauër       for (BasicBlock *Predecessor : Predecessors)
10381ed65febSNathan Gauër         replaceBranchTargets(Predecessor, &BB, Successor);
10391ed65febSNathan Gauër       ToRemove.push_back(&BB);
10401ed65febSNathan Gauër     }
10411ed65febSNathan Gauër 
10421ed65febSNathan Gauër     for (BasicBlock *BB : ToRemove)
10431ed65febSNathan Gauër       BB->eraseFromParent();
10441ed65febSNathan Gauër 
10451ed65febSNathan Gauër     return ToRemove.size() != 0;
10461ed65febSNathan Gauër   }
10471ed65febSNathan Gauër 
10481ed65febSNathan Gauër   bool addHeaderToRemainingDivergentDAG(Function &F) {
10491ed65febSNathan Gauër     bool Modified = false;
10501ed65febSNathan Gauër 
10511ed65febSNathan Gauër     auto MergeBlocks = getMergeBlocks(F);
10521ed65febSNathan Gauër     auto ContinueBlocks = getContinueBlocks(F);
10531ed65febSNathan Gauër     auto HeaderBlocks = getHeaderBlocks(F);
10541ed65febSNathan Gauër 
10551ed65febSNathan Gauër     DomTreeBuilder::BBDomTree DT;
10561ed65febSNathan Gauër     DomTreeBuilder::BBPostDomTree PDT;
10571ed65febSNathan Gauër     PDT.recalculate(F);
10581ed65febSNathan Gauër     DT.recalculate(F);
10591ed65febSNathan Gauër 
10601ed65febSNathan Gauër     for (BasicBlock &BB : F) {
10611ed65febSNathan Gauër       if (HeaderBlocks.count(&BB) != 0)
10621ed65febSNathan Gauër         continue;
10631ed65febSNathan Gauër       if (succ_size(&BB) < 2)
10641ed65febSNathan Gauër         continue;
10651ed65febSNathan Gauër 
10661ed65febSNathan Gauër       size_t CandidateEdges = 0;
10671ed65febSNathan Gauër       for (BasicBlock *Successor : successors(&BB)) {
10681ed65febSNathan Gauër         if (MergeBlocks.count(Successor) != 0 ||
10691ed65febSNathan Gauër             ContinueBlocks.count(Successor) != 0)
10701ed65febSNathan Gauër           continue;
10711ed65febSNathan Gauër         if (HeaderBlocks.count(Successor) != 0)
10721ed65febSNathan Gauër           continue;
10731ed65febSNathan Gauër         CandidateEdges += 1;
10741ed65febSNathan Gauër       }
10751ed65febSNathan Gauër 
10761ed65febSNathan Gauër       if (CandidateEdges <= 1)
10771ed65febSNathan Gauër         continue;
10781ed65febSNathan Gauër 
10791ed65febSNathan Gauër       BasicBlock *Header = &BB;
10801ed65febSNathan Gauër       BasicBlock *Merge = PDT.getNode(&BB)->getIDom()->getBlock();
10811ed65febSNathan Gauër 
10821ed65febSNathan Gauër       bool HasBadBlock = false;
10831ed65febSNathan Gauër       visit(*Header, [&](const BasicBlock *Node) {
10841ed65febSNathan Gauër         if (DT.dominates(Header, Node))
10851ed65febSNathan Gauër           return false;
10861ed65febSNathan Gauër         if (PDT.dominates(Merge, Node))
10871ed65febSNathan Gauër           return false;
10881ed65febSNathan Gauër         if (Node == Header || Node == Merge)
10891ed65febSNathan Gauër           return true;
10901ed65febSNathan Gauër 
10911ed65febSNathan Gauër         HasBadBlock |= MergeBlocks.count(Node) != 0 ||
10921ed65febSNathan Gauër                        ContinueBlocks.count(Node) != 0 ||
10931ed65febSNathan Gauër                        HeaderBlocks.count(Node) != 0;
10941ed65febSNathan Gauër         return !HasBadBlock;
10951ed65febSNathan Gauër       });
10961ed65febSNathan Gauër 
10971ed65febSNathan Gauër       if (HasBadBlock)
10981ed65febSNathan Gauër         continue;
10991ed65febSNathan Gauër 
11001ed65febSNathan Gauër       Modified = true;
1101cba70550SNathan Gauër 
1102cba70550SNathan Gauër       if (Merge == nullptr) {
1103cba70550SNathan Gauër         Merge = *successors(Header).begin();
1104cba70550SNathan Gauër         IRBuilder<> Builder(Header);
1105cba70550SNathan Gauër         Builder.SetInsertPoint(Header->getTerminator());
1106cba70550SNathan Gauër 
1107cba70550SNathan Gauër         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
1108380bb51bSjoaosaffran         createOpSelectMerge(&Builder, MergeAddress);
1109cba70550SNathan Gauër         continue;
1110cba70550SNathan Gauër       }
1111cba70550SNathan Gauër 
11121ed65febSNathan Gauër       Instruction *SplitInstruction = Merge->getTerminator();
11131ed65febSNathan Gauër       if (isMergeInstruction(SplitInstruction->getPrevNode()))
11141ed65febSNathan Gauër         SplitInstruction = SplitInstruction->getPrevNode();
11151ed65febSNathan Gauër       BasicBlock *NewMerge =
11161ed65febSNathan Gauër           Merge->splitBasicBlockBefore(SplitInstruction, "new.merge");
11171ed65febSNathan Gauër 
11181ed65febSNathan Gauër       IRBuilder<> Builder(Header);
11191ed65febSNathan Gauër       Builder.SetInsertPoint(Header->getTerminator());
11201ed65febSNathan Gauër 
11211ed65febSNathan Gauër       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
1122380bb51bSjoaosaffran       createOpSelectMerge(&Builder, MergeAddress);
11231ed65febSNathan Gauër     }
11241ed65febSNathan Gauër 
11251ed65febSNathan Gauër     return Modified;
11261ed65febSNathan Gauër   }
11271ed65febSNathan Gauër 
11281ed65febSNathan Gauër public:
11291ed65febSNathan Gauër   static char ID;
11301ed65febSNathan Gauër 
11311ed65febSNathan Gauër   SPIRVStructurizer() : FunctionPass(ID) {
11321ed65febSNathan Gauër     initializeSPIRVStructurizerPass(*PassRegistry::getPassRegistry());
11331ed65febSNathan Gauër   };
11341ed65febSNathan Gauër 
11351ed65febSNathan Gauër   virtual bool runOnFunction(Function &F) override {
11361ed65febSNathan Gauër     bool Modified = false;
11371ed65febSNathan Gauër 
11381ed65febSNathan Gauër     // In LLVM, Switches are allowed to have several cases branching to the same
11391ed65febSNathan Gauër     // basic block. This is allowed in SPIR-V, but can make structurizing SPIR-V
11401ed65febSNathan Gauër     // harder, so first remove edge cases.
11411ed65febSNathan Gauër     Modified |= splitSwitchCases(F);
11421ed65febSNathan Gauër 
11431ed65febSNathan Gauër     // LLVM allows conditional branches to have both side jumping to the same
11441ed65febSNathan Gauër     // block. It also allows switched to have a single default, or just one
11451ed65febSNathan Gauër     // case. Cleaning this up now.
11461ed65febSNathan Gauër     Modified |= simplifyBranches(F);
11471ed65febSNathan Gauër 
11481ed65febSNathan Gauër     // At this state, we should have a reducible CFG with cycles.
11491ed65febSNathan Gauër     // STEP 1: Adding OpLoopMerge instructions to loop headers.
11501ed65febSNathan Gauër     Modified |= addMergeForLoops(F);
11511ed65febSNathan Gauër 
11521ed65febSNathan Gauër     // STEP 2: adding OpSelectionMerge to each node with an in-degree >= 2.
11531ed65febSNathan Gauër     Modified |= addMergeForNodesWithMultiplePredecessors(F);
11541ed65febSNathan Gauër 
11551ed65febSNathan Gauër     // STEP 3:
11561ed65febSNathan Gauër     // Sort selection merge, the largest construct goes first.
11571ed65febSNathan Gauër     // This simplifies the next step.
11581ed65febSNathan Gauër     Modified |= sortSelectionMergeHeaders(F);
11591ed65febSNathan Gauër 
11601ed65febSNathan Gauër     // STEP 4: As this stage, we can have a single basic block with multiple
11611ed65febSNathan Gauër     // OpLoopMerge/OpSelectionMerge instructions. Splitting this block so each
11621ed65febSNathan Gauër     // BB has a single merge instruction.
11631ed65febSNathan Gauër     Modified |= splitBlocksWithMultipleHeaders(F);
11641ed65febSNathan Gauër 
11651ed65febSNathan Gauër     // STEP 5: In the previous steps, we added merge blocks the loops and
11661ed65febSNathan Gauër     // natural merge blocks (in-degree >= 2). What remains are conditions with
11671ed65febSNathan Gauër     // an exiting branch (return, unreachable). In such case, we must start from
11681ed65febSNathan Gauër     // the header, and add headers to divergent construct with no headers.
11691ed65febSNathan Gauër     Modified |= addMergeForDivergentBlocks(F);
11701ed65febSNathan Gauër 
11711ed65febSNathan Gauër     // STEP 6: At this stage, we have several divergent construct defines by a
11721ed65febSNathan Gauër     // header and a merge block. But their boundaries have no constraints: a
11731ed65febSNathan Gauër     // construct exit could be outside of the parents' construct exit. Such
11741ed65febSNathan Gauër     // edges are called critical edges. What we need is to split those edges
11751ed65febSNathan Gauër     // into several parts. Each part exiting the parent's construct by its merge
11761ed65febSNathan Gauër     // block.
11771ed65febSNathan Gauër     Modified |= splitCriticalEdges(F);
11781ed65febSNathan Gauër 
11791ed65febSNathan Gauër     // STEP 7: The previous steps possibly created a lot of "proxy" blocks.
11801ed65febSNathan Gauër     // Blocks with a single unconditional branch, used to create a valid
11811ed65febSNathan Gauër     // divergent construct tree. Some nodes are still requires (e.g: nodes
11821ed65febSNathan Gauër     // allowing a valid exit through the parent's merge block). But some are
11831ed65febSNathan Gauër     // left-overs of past transformations, and could cause actual validation
11841ed65febSNathan Gauër     // issues. E.g: the SPIR-V spec allows a construct to break to the parents
11851ed65febSNathan Gauër     // loop construct without an OpSelectionMerge, but this requires a straight
11861ed65febSNathan Gauër     // jump. If a proxy block lies between the conditional branch and the
11871ed65febSNathan Gauër     // parent's merge, the CFG is not valid.
11881ed65febSNathan Gauër     Modified |= removeUselessBlocks(F);
11891ed65febSNathan Gauër 
11901ed65febSNathan Gauër     // STEP 8: Final fix-up steps: our tree boundaries are correct, but some
11911ed65febSNathan Gauër     // blocks are branching with no header. Those are often simple conditional
11921ed65febSNathan Gauër     // branches with 1 or 2 returning edges. Adding a header for those.
11931ed65febSNathan Gauër     Modified |= addHeaderToRemainingDivergentDAG(F);
11941ed65febSNathan Gauër 
11951ed65febSNathan Gauër     // STEP 9: sort basic blocks to match both the LLVM & SPIR-V requirements.
11961ed65febSNathan Gauër     Modified |= sortBlocks(F);
11971ed65febSNathan Gauër 
11981ed65febSNathan Gauër     return Modified;
11991ed65febSNathan Gauër   }
12001ed65febSNathan Gauër 
12011ed65febSNathan Gauër   void getAnalysisUsage(AnalysisUsage &AU) const override {
12021ed65febSNathan Gauër     AU.addRequired<DominatorTreeWrapperPass>();
12031ed65febSNathan Gauër     AU.addRequired<LoopInfoWrapperPass>();
12041ed65febSNathan Gauër     AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
12051ed65febSNathan Gauër 
12061ed65febSNathan Gauër     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
12071ed65febSNathan Gauër     FunctionPass::getAnalysisUsage(AU);
12081ed65febSNathan Gauër   }
1209380bb51bSjoaosaffran 
1210380bb51bSjoaosaffran   void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1211380bb51bSjoaosaffran     Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
1212380bb51bSjoaosaffran 
1213380bb51bSjoaosaffran     MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
1214380bb51bSjoaosaffran 
1215380bb51bSjoaosaffran     ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0);
1216380bb51bSjoaosaffran 
1217380bb51bSjoaosaffran     if (MDNode) {
1218380bb51bSjoaosaffran       assert(MDNode->getNumOperands() == 2 &&
1219380bb51bSjoaosaffran              "invalid metadata hlsl.controlflow.hint");
1220380bb51bSjoaosaffran       BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
1221380bb51bSjoaosaffran 
1222380bb51bSjoaosaffran       assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint");
1223380bb51bSjoaosaffran     }
1224380bb51bSjoaosaffran 
1225380bb51bSjoaosaffran     llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint};
1226380bb51bSjoaosaffran 
1227380bb51bSjoaosaffran     Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
1228380bb51bSjoaosaffran                              {MergeAddress->getType()}, {Args});
1229380bb51bSjoaosaffran   }
12301ed65febSNathan Gauër };
12311ed65febSNathan Gauër } // namespace llvm
12321ed65febSNathan Gauër 
12331ed65febSNathan Gauër char SPIRVStructurizer::ID = 0;
12341ed65febSNathan Gauër 
123510b1caf6Sjoaosaffran INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "spirv-structurizer",
123610b1caf6Sjoaosaffran                       "structurize SPIRV", false, false)
12371ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
12381ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
12391ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
12401ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
12411ed65febSNathan Gauër 
124210b1caf6Sjoaosaffran INITIALIZE_PASS_END(SPIRVStructurizer, "spirv-structurizer",
124310b1caf6Sjoaosaffran                     "structurize SPIRV", false, false)
12441ed65febSNathan Gauër 
12451ed65febSNathan Gauër FunctionPass *llvm::createSPIRVStructurizerPass() {
12461ed65febSNathan Gauër   return new SPIRVStructurizer();
12471ed65febSNathan Gauër }
124810b1caf6Sjoaosaffran 
124910b1caf6Sjoaosaffran PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
125010b1caf6Sjoaosaffran                                                 FunctionAnalysisManager &AF) {
1251380bb51bSjoaosaffran 
1252380bb51bSjoaosaffran   auto FPM = legacy::FunctionPassManager(F.getParent());
1253380bb51bSjoaosaffran   FPM.add(createSPIRVStructurizerPass());
1254380bb51bSjoaosaffran 
1255380bb51bSjoaosaffran   if (!FPM.run(F))
125610b1caf6Sjoaosaffran     return PreservedAnalyses::all();
125710b1caf6Sjoaosaffran   PreservedAnalyses PA;
125810b1caf6Sjoaosaffran   PA.preserveSet<CFGAnalyses>();
125910b1caf6Sjoaosaffran   return PA;
126010b1caf6Sjoaosaffran }
1261