11ed65febSNathan Gauër //===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===// 21ed65febSNathan Gauër // 31ed65febSNathan Gauër // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41ed65febSNathan Gauër // See https://llvm.org/LICENSE.txt for license information. 51ed65febSNathan Gauër // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61ed65febSNathan Gauër // 71ed65febSNathan Gauër //===----------------------------------------------------------------------===// 81ed65febSNathan Gauër // 91ed65febSNathan Gauër //===----------------------------------------------------------------------===// 101ed65febSNathan Gauër 111ed65febSNathan Gauër #include "Analysis/SPIRVConvergenceRegionAnalysis.h" 121ed65febSNathan Gauër #include "SPIRV.h" 1310b1caf6Sjoaosaffran #include "SPIRVStructurizerWrapper.h" 141ed65febSNathan Gauër #include "SPIRVSubtarget.h" 151ed65febSNathan Gauër #include "SPIRVTargetMachine.h" 161ed65febSNathan Gauër #include "SPIRVUtils.h" 171ed65febSNathan Gauër #include "llvm/ADT/DenseMap.h" 181ed65febSNathan Gauër #include "llvm/ADT/SmallPtrSet.h" 191ed65febSNathan Gauër #include "llvm/Analysis/LoopInfo.h" 201ed65febSNathan Gauër #include "llvm/CodeGen/IntrinsicLowering.h" 211ed65febSNathan Gauër #include "llvm/IR/CFG.h" 221ed65febSNathan Gauër #include "llvm/IR/Dominators.h" 231ed65febSNathan Gauër #include "llvm/IR/IRBuilder.h" 241ed65febSNathan Gauër #include "llvm/IR/IntrinsicInst.h" 251ed65febSNathan Gauër #include "llvm/IR/Intrinsics.h" 261ed65febSNathan Gauër #include "llvm/IR/IntrinsicsSPIRV.h" 27380bb51bSjoaosaffran #include "llvm/IR/LegacyPassManager.h" 281ed65febSNathan Gauër #include "llvm/InitializePasses.h" 29380bb51bSjoaosaffran #include "llvm/PassRegistry.h" 30380bb51bSjoaosaffran #include "llvm/Transforms/Utils.h" 311ed65febSNathan Gauër #include "llvm/Transforms/Utils/Cloning.h" 321ed65febSNathan Gauër #include "llvm/Transforms/Utils/LoopSimplify.h" 331ed65febSNathan Gauër #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 341ed65febSNathan Gauër #include <queue> 351ed65febSNathan Gauër #include <stack> 361ed65febSNathan Gauër #include <unordered_set> 371ed65febSNathan Gauër 381ed65febSNathan Gauër using namespace llvm; 391ed65febSNathan Gauër using namespace SPIRV; 401ed65febSNathan Gauër 411ed65febSNathan Gauër namespace llvm { 421ed65febSNathan Gauër 431ed65febSNathan Gauër void initializeSPIRVStructurizerPass(PassRegistry &); 441ed65febSNathan Gauër 451ed65febSNathan Gauër namespace { 461ed65febSNathan Gauër 471ed65febSNathan Gauër using BlockSet = std::unordered_set<BasicBlock *>; 481ed65febSNathan Gauër using Edge = std::pair<BasicBlock *, BasicBlock *>; 491ed65febSNathan Gauër 501ed65febSNathan Gauër // Helper function to do a partial order visit from the block |Start|, calling 511ed65febSNathan Gauër // |Op| on each visited node. 521ed65febSNathan Gauër void partialOrderVisit(BasicBlock &Start, 531ed65febSNathan Gauër std::function<bool(BasicBlock *)> Op) { 541ed65febSNathan Gauër PartialOrderingVisitor V(*Start.getParent()); 551ed65febSNathan Gauër V.partialOrderVisit(Start, Op); 561ed65febSNathan Gauër } 571ed65febSNathan Gauër 581ed65febSNathan Gauër // Returns the exact convergence region in the tree defined by `Node` for which 591ed65febSNathan Gauër // `BB` is the header, nullptr otherwise. 601ed65febSNathan Gauër const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node, 611ed65febSNathan Gauër BasicBlock *BB) { 621ed65febSNathan Gauër if (Node->Entry == BB) 631ed65febSNathan Gauër return Node; 641ed65febSNathan Gauër 651ed65febSNathan Gauër for (auto *Child : Node->Children) { 661ed65febSNathan Gauër const auto *CR = getRegionForHeader(Child, BB); 671ed65febSNathan Gauër if (CR != nullptr) 681ed65febSNathan Gauër return CR; 691ed65febSNathan Gauër } 701ed65febSNathan Gauër return nullptr; 711ed65febSNathan Gauër } 721ed65febSNathan Gauër 731ed65febSNathan Gauër // Returns the single BasicBlock exiting the convergence region `CR`, 741ed65febSNathan Gauër // nullptr if no such exit exists. 751ed65febSNathan Gauër BasicBlock *getExitFor(const ConvergenceRegion *CR) { 761ed65febSNathan Gauër std::unordered_set<BasicBlock *> ExitTargets; 771ed65febSNathan Gauër for (BasicBlock *Exit : CR->Exits) { 781ed65febSNathan Gauër for (BasicBlock *Successor : successors(Exit)) { 791ed65febSNathan Gauër if (CR->Blocks.count(Successor) == 0) 801ed65febSNathan Gauër ExitTargets.insert(Successor); 811ed65febSNathan Gauër } 821ed65febSNathan Gauër } 831ed65febSNathan Gauër 841ed65febSNathan Gauër assert(ExitTargets.size() <= 1); 851ed65febSNathan Gauër if (ExitTargets.size() == 0) 861ed65febSNathan Gauër return nullptr; 871ed65febSNathan Gauër 881ed65febSNathan Gauër return *ExitTargets.begin(); 891ed65febSNathan Gauër } 901ed65febSNathan Gauër 911ed65febSNathan Gauër // Returns the merge block designated by I if I is a merge instruction, nullptr 921ed65febSNathan Gauër // otherwise. 931ed65febSNathan Gauër BasicBlock *getDesignatedMergeBlock(Instruction *I) { 94cba70550SNathan Gauër IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I); 951ed65febSNathan Gauër if (II == nullptr) 961ed65febSNathan Gauër return nullptr; 971ed65febSNathan Gauër 981ed65febSNathan Gauër if (II->getIntrinsicID() != Intrinsic::spv_loop_merge && 991ed65febSNathan Gauër II->getIntrinsicID() != Intrinsic::spv_selection_merge) 1001ed65febSNathan Gauër return nullptr; 1011ed65febSNathan Gauër 1021ed65febSNathan Gauër BlockAddress *BA = cast<BlockAddress>(II->getOperand(0)); 1031ed65febSNathan Gauër return BA->getBasicBlock(); 1041ed65febSNathan Gauër } 1051ed65febSNathan Gauër 1061ed65febSNathan Gauër // Returns the continue block designated by I if I is an OpLoopMerge, nullptr 1071ed65febSNathan Gauër // otherwise. 1081ed65febSNathan Gauër BasicBlock *getDesignatedContinueBlock(Instruction *I) { 109cba70550SNathan Gauër IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I); 1101ed65febSNathan Gauër if (II == nullptr) 1111ed65febSNathan Gauër return nullptr; 1121ed65febSNathan Gauër 1131ed65febSNathan Gauër if (II->getIntrinsicID() != Intrinsic::spv_loop_merge) 1141ed65febSNathan Gauër return nullptr; 1151ed65febSNathan Gauër 1161ed65febSNathan Gauër BlockAddress *BA = cast<BlockAddress>(II->getOperand(1)); 1171ed65febSNathan Gauër return BA->getBasicBlock(); 1181ed65febSNathan Gauër } 1191ed65febSNathan Gauër 1201ed65febSNathan Gauër // Returns true if Header has one merge instruction which designated Merge as 1211ed65febSNathan Gauër // merge block. 1221ed65febSNathan Gauër bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) { 1231ed65febSNathan Gauër for (auto &I : Header) { 1241ed65febSNathan Gauër BasicBlock *MB = getDesignatedMergeBlock(&I); 1251ed65febSNathan Gauër if (MB == &Merge) 1261ed65febSNathan Gauër return true; 1271ed65febSNathan Gauër } 1281ed65febSNathan Gauër return false; 1291ed65febSNathan Gauër } 1301ed65febSNathan Gauër 1311ed65febSNathan Gauër // Returns true if the BB has one OpLoopMerge instruction. 1321ed65febSNathan Gauër bool hasLoopMergeInstruction(BasicBlock &BB) { 1331ed65febSNathan Gauër for (auto &I : BB) 1341ed65febSNathan Gauër if (getDesignatedContinueBlock(&I)) 1351ed65febSNathan Gauër return true; 1361ed65febSNathan Gauër return false; 1371ed65febSNathan Gauër } 1381ed65febSNathan Gauër 1391ed65febSNathan Gauër // Returns true is I is an OpSelectionMerge or OpLoopMerge instruction, false 1401ed65febSNathan Gauër // otherwise. 1411ed65febSNathan Gauër bool isMergeInstruction(Instruction *I) { 1421ed65febSNathan Gauër return getDesignatedMergeBlock(I) != nullptr; 1431ed65febSNathan Gauër } 1441ed65febSNathan Gauër 1451ed65febSNathan Gauër // Returns all blocks in F having at least one OpLoopMerge or OpSelectionMerge 1461ed65febSNathan Gauër // instruction. 1471ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) { 1481ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> Output; 1491ed65febSNathan Gauër for (BasicBlock &BB : F) { 1501ed65febSNathan Gauër for (Instruction &I : BB) { 1511ed65febSNathan Gauër if (getDesignatedMergeBlock(&I) != nullptr) 1521ed65febSNathan Gauër Output.insert(&BB); 1531ed65febSNathan Gauër } 1541ed65febSNathan Gauër } 1551ed65febSNathan Gauër return Output; 1561ed65febSNathan Gauër } 1571ed65febSNathan Gauër 1581ed65febSNathan Gauër // Returns all basic blocks in |F| referenced by at least 1 1591ed65febSNathan Gauër // OpSelectionMerge/OpLoopMerge instruction. 1601ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) { 1611ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> Output; 1621ed65febSNathan Gauër for (BasicBlock &BB : F) { 1631ed65febSNathan Gauër for (Instruction &I : BB) { 1641ed65febSNathan Gauër BasicBlock *MB = getDesignatedMergeBlock(&I); 1651ed65febSNathan Gauër if (MB != nullptr) 1661ed65febSNathan Gauër Output.insert(MB); 1671ed65febSNathan Gauër } 1681ed65febSNathan Gauër } 1691ed65febSNathan Gauër return Output; 1701ed65febSNathan Gauër } 1711ed65febSNathan Gauër 1721ed65febSNathan Gauër // Return all the merge instructions contained in BB. 1731ed65febSNathan Gauër // Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge 1741ed65febSNathan Gauër // instruction, but this can happen while we structurize the CFG. 1751ed65febSNathan Gauër std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) { 1761ed65febSNathan Gauër std::vector<Instruction *> Output; 1771ed65febSNathan Gauër for (Instruction &I : BB) 1781ed65febSNathan Gauër if (isMergeInstruction(&I)) 1791ed65febSNathan Gauër Output.push_back(&I); 1801ed65febSNathan Gauër return Output; 1811ed65febSNathan Gauër } 1821ed65febSNathan Gauër 1831ed65febSNathan Gauër // Returns all basic blocks in |F| referenced as continue target by at least 1 1841ed65febSNathan Gauër // OpLoopMerge instruction. 1851ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) { 1861ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 2> Output; 1871ed65febSNathan Gauër for (BasicBlock &BB : F) { 1881ed65febSNathan Gauër for (Instruction &I : BB) { 1891ed65febSNathan Gauër BasicBlock *MB = getDesignatedContinueBlock(&I); 1901ed65febSNathan Gauër if (MB != nullptr) 1911ed65febSNathan Gauër Output.insert(MB); 1921ed65febSNathan Gauër } 1931ed65febSNathan Gauër } 1941ed65febSNathan Gauër return Output; 1951ed65febSNathan Gauër } 1961ed65febSNathan Gauër 1971ed65febSNathan Gauër // Do a preorder traversal of the CFG starting from the BB |Start|. 1981ed65febSNathan Gauër // point. Calls |op| on each basic block encountered during the traversal. 1991ed65febSNathan Gauër void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) { 2001ed65febSNathan Gauër std::stack<BasicBlock *> ToVisit; 2011ed65febSNathan Gauër SmallPtrSet<BasicBlock *, 8> Seen; 2021ed65febSNathan Gauër 2031ed65febSNathan Gauër ToVisit.push(&Start); 2041ed65febSNathan Gauër Seen.insert(ToVisit.top()); 2051ed65febSNathan Gauër while (ToVisit.size() != 0) { 2061ed65febSNathan Gauër BasicBlock *BB = ToVisit.top(); 2071ed65febSNathan Gauër ToVisit.pop(); 2081ed65febSNathan Gauër 2091ed65febSNathan Gauër if (!op(BB)) 2101ed65febSNathan Gauër continue; 2111ed65febSNathan Gauër 2121ed65febSNathan Gauër for (auto Succ : successors(BB)) { 2131ed65febSNathan Gauër if (Seen.contains(Succ)) 2141ed65febSNathan Gauër continue; 2151ed65febSNathan Gauër ToVisit.push(Succ); 2161ed65febSNathan Gauër Seen.insert(Succ); 2171ed65febSNathan Gauër } 2181ed65febSNathan Gauër } 2191ed65febSNathan Gauër } 2201ed65febSNathan Gauër 2211ed65febSNathan Gauër // Replaces the conditional and unconditional branch targets of |BB| by 2221ed65febSNathan Gauër // |NewTarget| if the target was |OldTarget|. This function also makes sure the 2231ed65febSNathan Gauër // associated merge instruction gets updated accordingly. 2241ed65febSNathan Gauër void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, 2251ed65febSNathan Gauër BasicBlock *NewTarget) { 2261ed65febSNathan Gauër auto *BI = cast<BranchInst>(BB->getTerminator()); 2271ed65febSNathan Gauër 2281ed65febSNathan Gauër // 1. Replace all matching successors. 2291ed65febSNathan Gauër for (size_t i = 0; i < BI->getNumSuccessors(); i++) { 2301ed65febSNathan Gauër if (BI->getSuccessor(i) == OldTarget) 2311ed65febSNathan Gauër BI->setSuccessor(i, NewTarget); 2321ed65febSNathan Gauër } 2331ed65febSNathan Gauër 2341ed65febSNathan Gauër // Branch was unconditional, no fixup required. 2351ed65febSNathan Gauër if (BI->isUnconditional()) 2361ed65febSNathan Gauër return; 2371ed65febSNathan Gauër 2381ed65febSNathan Gauër // Branch had 2 successors, maybe now both are the same? 2391ed65febSNathan Gauër if (BI->getSuccessor(0) != BI->getSuccessor(1)) 2401ed65febSNathan Gauër return; 2411ed65febSNathan Gauër 2421ed65febSNathan Gauër // Note: we may end up here because the original IR had such branches. 2431ed65febSNathan Gauër // This means Target is not necessarily equal to NewTarget. 2441ed65febSNathan Gauër IRBuilder<> Builder(BB); 2451ed65febSNathan Gauër Builder.SetInsertPoint(BI); 2461ed65febSNathan Gauër Builder.CreateBr(BI->getSuccessor(0)); 2471ed65febSNathan Gauër BI->eraseFromParent(); 2481ed65febSNathan Gauër 2491ed65febSNathan Gauër // The branch was the only instruction, nothing else to do. 2501ed65febSNathan Gauër if (BB->size() == 1) 2511ed65febSNathan Gauër return; 2521ed65febSNathan Gauër 2531ed65febSNathan Gauër // Otherwise, we need to check: was there an OpSelectionMerge before this 2541ed65febSNathan Gauër // branch? If we removed the OpBranchConditional, we must also remove the 2551ed65febSNathan Gauër // OpSelectionMerge. This is not valid for OpLoopMerge: 2561ed65febSNathan Gauër IntrinsicInst *II = 2571ed65febSNathan Gauër dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode()); 2581ed65febSNathan Gauër if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge) 2591ed65febSNathan Gauër return; 2601ed65febSNathan Gauër 2611ed65febSNathan Gauër Constant *C = cast<Constant>(II->getOperand(0)); 2621ed65febSNathan Gauër II->eraseFromParent(); 2631ed65febSNathan Gauër if (!C->isConstantUsed()) 2641ed65febSNathan Gauër C->destroyConstant(); 2651ed65febSNathan Gauër } 2661ed65febSNathan Gauër 2671ed65febSNathan Gauër // Replaces the target of branch instruction in |BB| with |NewTarget| if it 2681ed65febSNathan Gauër // was |OldTarget|. This function also fixes the associated merge instruction. 2691ed65febSNathan Gauër // Note: this function does not simplify branching instructions, it only updates 2701ed65febSNathan Gauër // targets. See also: simplifyBranches. 2711ed65febSNathan Gauër void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, 2721ed65febSNathan Gauër BasicBlock *NewTarget) { 2731ed65febSNathan Gauër auto *T = BB->getTerminator(); 2741ed65febSNathan Gauër if (isa<ReturnInst>(T)) 2751ed65febSNathan Gauër return; 2761ed65febSNathan Gauër 2771ed65febSNathan Gauër if (isa<BranchInst>(T)) 2781ed65febSNathan Gauër return replaceIfBranchTargets(BB, OldTarget, NewTarget); 2791ed65febSNathan Gauër 2801ed65febSNathan Gauër if (auto *SI = dyn_cast<SwitchInst>(T)) { 2811ed65febSNathan Gauër for (size_t i = 0; i < SI->getNumSuccessors(); i++) { 2821ed65febSNathan Gauër if (SI->getSuccessor(i) == OldTarget) 2831ed65febSNathan Gauër SI->setSuccessor(i, NewTarget); 2841ed65febSNathan Gauër } 2851ed65febSNathan Gauër return; 2861ed65febSNathan Gauër } 2871ed65febSNathan Gauër 2881ed65febSNathan Gauër assert(false && "Unhandled terminator type."); 2891ed65febSNathan Gauër } 2901ed65febSNathan Gauër 2911ed65febSNathan Gauër } // anonymous namespace 2921ed65febSNathan Gauër 2931ed65febSNathan Gauër // Given a reducible CFG, produces a structurized CFG in the SPIR-V sense, 2941ed65febSNathan Gauër // adding merge instructions when required. 2951ed65febSNathan Gauër class SPIRVStructurizer : public FunctionPass { 2961ed65febSNathan Gauër 2971ed65febSNathan Gauër struct DivergentConstruct; 2981ed65febSNathan Gauër // Represents a list of condition/loops/switch constructs. 2991ed65febSNathan Gauër // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of 3001ed65febSNathan Gauër // constructs. 3011ed65febSNathan Gauër using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>; 3021ed65febSNathan Gauër 3031ed65febSNathan Gauër // Represents a divergent construct in the SPIR-V sense. 3041ed65febSNathan Gauër // Such constructs are represented by a header (entry), a merge block (exit), 3051ed65febSNathan Gauër // and possibly a continue block (back-edge). A construct can contain other 3061ed65febSNathan Gauër // constructs, but their boundaries do not cross. 3071ed65febSNathan Gauër struct DivergentConstruct { 3081ed65febSNathan Gauër BasicBlock *Header = nullptr; 3091ed65febSNathan Gauër BasicBlock *Merge = nullptr; 3101ed65febSNathan Gauër BasicBlock *Continue = nullptr; 3111ed65febSNathan Gauër 3121ed65febSNathan Gauër DivergentConstruct *Parent = nullptr; 3131ed65febSNathan Gauër ConstructList Children; 3141ed65febSNathan Gauër }; 3151ed65febSNathan Gauër 3161ed65febSNathan Gauër // An helper class to clean the construct boundaries. 3171ed65febSNathan Gauër // It is used to gather the list of blocks that should belong to each 3181ed65febSNathan Gauër // divergent construct, and possibly modify CFG edges when exits would cross 3191ed65febSNathan Gauër // the boundary of multiple constructs. 3201ed65febSNathan Gauër struct Splitter { 3211ed65febSNathan Gauër Function &F; 3221ed65febSNathan Gauër LoopInfo &LI; 3231ed65febSNathan Gauër DomTreeBuilder::BBDomTree DT; 3241ed65febSNathan Gauër DomTreeBuilder::BBPostDomTree PDT; 3251ed65febSNathan Gauër 3261ed65febSNathan Gauër Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); } 3271ed65febSNathan Gauër 3281ed65febSNathan Gauër void invalidate() { 3291ed65febSNathan Gauër PDT.recalculate(F); 3301ed65febSNathan Gauër DT.recalculate(F); 3311ed65febSNathan Gauër } 3321ed65febSNathan Gauër 3331ed65febSNathan Gauër // Returns the list of blocks that belong to a SPIR-V loop construct, 3341ed65febSNathan Gauër // including the continue construct. 3351ed65febSNathan Gauër std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header, 3361ed65febSNathan Gauër BasicBlock *Merge) { 3371ed65febSNathan Gauër assert(DT.dominates(Header, Merge)); 3381ed65febSNathan Gauër std::vector<BasicBlock *> Output; 3391ed65febSNathan Gauër partialOrderVisit(*Header, [&](BasicBlock *BB) { 3401ed65febSNathan Gauër if (BB == Merge) 3411ed65febSNathan Gauër return false; 3421ed65febSNathan Gauër if (DT.dominates(Merge, BB) || !DT.dominates(Header, BB)) 3431ed65febSNathan Gauër return false; 3441ed65febSNathan Gauër Output.push_back(BB); 3451ed65febSNathan Gauër return true; 3461ed65febSNathan Gauër }); 3471ed65febSNathan Gauër return Output; 3481ed65febSNathan Gauër } 3491ed65febSNathan Gauër 3501ed65febSNathan Gauër // Returns the list of blocks that belong to a SPIR-V selection construct. 3511ed65febSNathan Gauër std::vector<BasicBlock *> 3521ed65febSNathan Gauër getSelectionConstructBlocks(DivergentConstruct *Node) { 3531ed65febSNathan Gauër assert(DT.dominates(Node->Header, Node->Merge)); 3541ed65febSNathan Gauër BlockSet OutsideBlocks; 3551ed65febSNathan Gauër OutsideBlocks.insert(Node->Merge); 3561ed65febSNathan Gauër 3571ed65febSNathan Gauër for (DivergentConstruct *It = Node->Parent; It != nullptr; 3581ed65febSNathan Gauër It = It->Parent) { 3591ed65febSNathan Gauër OutsideBlocks.insert(It->Merge); 3601ed65febSNathan Gauër if (It->Continue) 3611ed65febSNathan Gauër OutsideBlocks.insert(It->Continue); 3621ed65febSNathan Gauër } 3631ed65febSNathan Gauër 3641ed65febSNathan Gauër std::vector<BasicBlock *> Output; 3651ed65febSNathan Gauër partialOrderVisit(*Node->Header, [&](BasicBlock *BB) { 3661ed65febSNathan Gauër if (OutsideBlocks.count(BB) != 0) 3671ed65febSNathan Gauër return false; 3681ed65febSNathan Gauër if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB)) 3691ed65febSNathan Gauër return false; 3701ed65febSNathan Gauër Output.push_back(BB); 3711ed65febSNathan Gauër return true; 3721ed65febSNathan Gauër }); 3731ed65febSNathan Gauër return Output; 3741ed65febSNathan Gauër } 3751ed65febSNathan Gauër 3761ed65febSNathan Gauër // Returns the list of blocks that belong to a SPIR-V switch construct. 3771ed65febSNathan Gauër std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header, 3781ed65febSNathan Gauër BasicBlock *Merge) { 3791ed65febSNathan Gauër assert(DT.dominates(Header, Merge)); 3801ed65febSNathan Gauër 3811ed65febSNathan Gauër std::vector<BasicBlock *> Output; 3821ed65febSNathan Gauër partialOrderVisit(*Header, [&](BasicBlock *BB) { 3831ed65febSNathan Gauër // the blocks structurally dominated by a switch header, 3841ed65febSNathan Gauër if (!DT.dominates(Header, BB)) 3851ed65febSNathan Gauër return false; 3861ed65febSNathan Gauër // excluding blocks structurally dominated by the switch header’s merge 3871ed65febSNathan Gauër // block. 3881ed65febSNathan Gauër if (DT.dominates(Merge, BB) || BB == Merge) 3891ed65febSNathan Gauër return false; 3901ed65febSNathan Gauër Output.push_back(BB); 3911ed65febSNathan Gauër return true; 3921ed65febSNathan Gauër }); 3931ed65febSNathan Gauër return Output; 3941ed65febSNathan Gauër } 3951ed65febSNathan Gauër 3961ed65febSNathan Gauër // Returns the list of blocks that belong to a SPIR-V case construct. 3971ed65febSNathan Gauër std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target, 3981ed65febSNathan Gauër BasicBlock *Merge) { 3991ed65febSNathan Gauër assert(DT.dominates(Target, Merge)); 4001ed65febSNathan Gauër 4011ed65febSNathan Gauër std::vector<BasicBlock *> Output; 4021ed65febSNathan Gauër partialOrderVisit(*Target, [&](BasicBlock *BB) { 4031ed65febSNathan Gauër // the blocks structurally dominated by an OpSwitch Target or Default 4041ed65febSNathan Gauër // block 4051ed65febSNathan Gauër if (!DT.dominates(Target, BB)) 4061ed65febSNathan Gauër return false; 4071ed65febSNathan Gauër // excluding the blocks structurally dominated by the OpSwitch 4081ed65febSNathan Gauër // construct’s corresponding merge block. 4091ed65febSNathan Gauër if (DT.dominates(Merge, BB) || BB == Merge) 4101ed65febSNathan Gauër return false; 4111ed65febSNathan Gauër Output.push_back(BB); 4121ed65febSNathan Gauër return true; 4131ed65febSNathan Gauër }); 4141ed65febSNathan Gauër return Output; 4151ed65febSNathan Gauër } 4161ed65febSNathan Gauër 4171ed65febSNathan Gauër // Splits the given edges by recreating proxy nodes so that the destination 418cba70550SNathan Gauër // has unique incoming edges from this region. 4191ed65febSNathan Gauër // 4201ed65febSNathan Gauër // clang-format off 4211ed65febSNathan Gauër // 4221ed65febSNathan Gauër // In SPIR-V, constructs must have a single exit/merge. 4231ed65febSNathan Gauër // Given nodes A and B in the construct, a node C outside, and the following edges. 4241ed65febSNathan Gauër // A -> C 4251ed65febSNathan Gauër // B -> C 4261ed65febSNathan Gauër // 4271ed65febSNathan Gauër // In such cases, we must create a new exit node D, that belong to the construct to make is viable: 4281ed65febSNathan Gauër // A -> D -> C 4291ed65febSNathan Gauër // B -> D -> C 4301ed65febSNathan Gauër // 431cba70550SNathan Gauër // This is fine (assuming C has no PHI nodes), but requires handling the merge instruction here. 432cba70550SNathan Gauër // By adding a proxy node, we create a regular divergent shape which can easily be regularized later on. 4331ed65febSNathan Gauër // A -> D -> D1 -> C 4341ed65febSNathan Gauër // B -> D -> D2 -> C 4351ed65febSNathan Gauër // 436cba70550SNathan Gauër // A, B, D belongs to the construct. D is the exit. D1 and D2 are empty. 4371ed65febSNathan Gauër // 4381ed65febSNathan Gauër // clang-format on 4391ed65febSNathan Gauër std::vector<Edge> 4401ed65febSNathan Gauër createAliasBlocksForComplexEdges(std::vector<Edge> Edges) { 441cba70550SNathan Gauër std::unordered_set<BasicBlock *> Seen; 4421ed65febSNathan Gauër std::vector<Edge> Output; 4431ed65febSNathan Gauër Output.reserve(Edges.size()); 4441ed65febSNathan Gauër 4451ed65febSNathan Gauër for (auto &[Src, Dst] : Edges) { 446cba70550SNathan Gauër auto [Iterator, Inserted] = Seen.insert(Src); 447cba70550SNathan Gauër if (!Inserted) { 448cba70550SNathan Gauër // Src already a source node. Cannot have 2 edges from A to B. 449cba70550SNathan Gauër // Creating alias source block. 450cba70550SNathan Gauër BasicBlock *NewSrc = BasicBlock::Create( 451cba70550SNathan Gauër F.getContext(), Src->getName() + ".new.src", &F); 4521ed65febSNathan Gauër replaceBranchTargets(Src, Dst, NewSrc); 4531ed65febSNathan Gauër IRBuilder<> Builder(NewSrc); 4541ed65febSNathan Gauër Builder.CreateBr(Dst); 455cba70550SNathan Gauër Src = NewSrc; 456cba70550SNathan Gauër } 4571ed65febSNathan Gauër 458cba70550SNathan Gauër Output.emplace_back(Src, Dst); 4591ed65febSNathan Gauër } 4601ed65febSNathan Gauër 4611ed65febSNathan Gauër return Output; 4621ed65febSNathan Gauër } 4631ed65febSNathan Gauër 464cba70550SNathan Gauër AllocaInst *CreateVariable(Function &F, Type *Type, 465cba70550SNathan Gauër BasicBlock::iterator Position) { 466cba70550SNathan Gauër const DataLayout &DL = F.getDataLayout(); 467cba70550SNathan Gauër return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg", 468cba70550SNathan Gauër Position); 469cba70550SNathan Gauër } 470cba70550SNathan Gauër 4711ed65febSNathan Gauër // Given a construct defined by |Header|, and a list of exiting edges 4721ed65febSNathan Gauër // |Edges|, creates a new single exit node, fixing up those edges. 4731ed65febSNathan Gauër BasicBlock *createSingleExitNode(BasicBlock *Header, 4741ed65febSNathan Gauër std::vector<Edge> &Edges) { 475cba70550SNathan Gauër 476cba70550SNathan Gauër std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges); 4771ed65febSNathan Gauër 4781ed65febSNathan Gauër std::vector<BasicBlock *> Dsts; 4791ed65febSNathan Gauër std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex; 480cba70550SNathan Gauër auto NewExit = BasicBlock::Create(F.getContext(), 481cba70550SNathan Gauër Header->getName() + ".new.exit", &F); 482cba70550SNathan Gauër IRBuilder<> ExitBuilder(NewExit); 4831ed65febSNathan Gauër for (auto &[Src, Dst] : FixedEdges) { 4841ed65febSNathan Gauër if (DstToIndex.count(Dst) != 0) 4851ed65febSNathan Gauër continue; 4861ed65febSNathan Gauër DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size())); 4871ed65febSNathan Gauër Dsts.push_back(Dst); 4881ed65febSNathan Gauër } 4891ed65febSNathan Gauër 4901ed65febSNathan Gauër if (Dsts.size() == 1) { 4911ed65febSNathan Gauër for (auto &[Src, Dst] : FixedEdges) { 4921ed65febSNathan Gauër replaceBranchTargets(Src, Dst, NewExit); 4931ed65febSNathan Gauër } 4941ed65febSNathan Gauër ExitBuilder.CreateBr(Dsts[0]); 4951ed65febSNathan Gauër return NewExit; 4961ed65febSNathan Gauër } 4971ed65febSNathan Gauër 498cba70550SNathan Gauër AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(), 499cba70550SNathan Gauër F.begin()->getFirstInsertionPt()); 5001ed65febSNathan Gauër for (auto &[Src, Dst] : FixedEdges) { 501cba70550SNathan Gauër IRBuilder<> B2(Src); 502cba70550SNathan Gauër B2.SetInsertPoint(Src->getFirstInsertionPt()); 503cba70550SNathan Gauër B2.CreateStore(DstToIndex[Dst], Variable); 5041ed65febSNathan Gauër replaceBranchTargets(Src, Dst, NewExit); 5051ed65febSNathan Gauër } 5061ed65febSNathan Gauër 507cba70550SNathan Gauër llvm::Value *Load = 508cba70550SNathan Gauër ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable); 509cba70550SNathan Gauër 5101ed65febSNathan Gauër // If we can avoid an OpSwitch, generate an OpBranch. Reason is some 5111ed65febSNathan Gauër // OpBranch are allowed to exist without a new OpSelectionMerge if one of 5121ed65febSNathan Gauër // the branch is the parent's merge node, while OpSwitches are not. 5131ed65febSNathan Gauër if (Dsts.size() == 2) { 514cba70550SNathan Gauër Value *Condition = 515cba70550SNathan Gauër ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load); 5161ed65febSNathan Gauër ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]); 5171ed65febSNathan Gauër return NewExit; 5181ed65febSNathan Gauër } 5191ed65febSNathan Gauër 520cba70550SNathan Gauër SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1); 5211ed65febSNathan Gauër for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) { 5221ed65febSNathan Gauër Sw->addCase(DstToIndex[*It], *It); 5231ed65febSNathan Gauër } 5241ed65febSNathan Gauër return NewExit; 5251ed65febSNathan Gauër } 5261ed65febSNathan Gauër }; 5271ed65febSNathan Gauër 5281ed65febSNathan Gauër /// Create a value in BB set to the value associated with the branch the block 5291ed65febSNathan Gauër /// terminator will take. 5301ed65febSNathan Gauër Value *createExitVariable( 5311ed65febSNathan Gauër BasicBlock *BB, 5321ed65febSNathan Gauër const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { 5331ed65febSNathan Gauër auto *T = BB->getTerminator(); 5341ed65febSNathan Gauër if (isa<ReturnInst>(T)) 5351ed65febSNathan Gauër return nullptr; 5361ed65febSNathan Gauër 5371ed65febSNathan Gauër IRBuilder<> Builder(BB); 5381ed65febSNathan Gauër Builder.SetInsertPoint(T); 5391ed65febSNathan Gauër 5401ed65febSNathan Gauër if (auto *BI = dyn_cast<BranchInst>(T)) { 5411ed65febSNathan Gauër 5421ed65febSNathan Gauër BasicBlock *LHSTarget = BI->getSuccessor(0); 5431ed65febSNathan Gauër BasicBlock *RHSTarget = 5441ed65febSNathan Gauër BI->isConditional() ? BI->getSuccessor(1) : nullptr; 5451ed65febSNathan Gauër 5461ed65febSNathan Gauër Value *LHS = TargetToValue.count(LHSTarget) != 0 5471ed65febSNathan Gauër ? TargetToValue.at(LHSTarget) 5481ed65febSNathan Gauër : nullptr; 5491ed65febSNathan Gauër Value *RHS = TargetToValue.count(RHSTarget) != 0 5501ed65febSNathan Gauër ? TargetToValue.at(RHSTarget) 5511ed65febSNathan Gauër : nullptr; 5521ed65febSNathan Gauër 5531ed65febSNathan Gauër if (LHS == nullptr || RHS == nullptr) 5541ed65febSNathan Gauër return LHS == nullptr ? RHS : LHS; 5551ed65febSNathan Gauër return Builder.CreateSelect(BI->getCondition(), LHS, RHS); 5561ed65febSNathan Gauër } 5571ed65febSNathan Gauër 5581ed65febSNathan Gauër // TODO: add support for switch cases. 5591ed65febSNathan Gauër llvm_unreachable("Unhandled terminator type."); 5601ed65febSNathan Gauër } 5611ed65febSNathan Gauër 5621ed65febSNathan Gauër // Creates a new basic block in F with a single OpUnreachable instruction. 5631ed65febSNathan Gauër BasicBlock *CreateUnreachable(Function &F) { 564cba70550SNathan Gauër BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F); 5651ed65febSNathan Gauër IRBuilder<> Builder(BB); 5661ed65febSNathan Gauër Builder.CreateUnreachable(); 5671ed65febSNathan Gauër return BB; 5681ed65febSNathan Gauër } 5691ed65febSNathan Gauër 5701ed65febSNathan Gauër // Add OpLoopMerge instruction on cycles. 5711ed65febSNathan Gauër bool addMergeForLoops(Function &F) { 5721ed65febSNathan Gauër LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 5731ed65febSNathan Gauër auto *TopLevelRegion = 5741ed65febSNathan Gauër getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 5751ed65febSNathan Gauër .getRegionInfo() 5761ed65febSNathan Gauër .getTopLevelRegion(); 5771ed65febSNathan Gauër 5781ed65febSNathan Gauër bool Modified = false; 5791ed65febSNathan Gauër for (auto &BB : F) { 5801ed65febSNathan Gauër // Not a loop header. Ignoring for now. 5811ed65febSNathan Gauër if (!LI.isLoopHeader(&BB)) 5821ed65febSNathan Gauër continue; 5831ed65febSNathan Gauër auto *L = LI.getLoopFor(&BB); 5841ed65febSNathan Gauër 5851ed65febSNathan Gauër // This loop header is not the entrance of a convergence region. Ignoring 5861ed65febSNathan Gauër // this block. 5871ed65febSNathan Gauër auto *CR = getRegionForHeader(TopLevelRegion, &BB); 5881ed65febSNathan Gauër if (CR == nullptr) 5891ed65febSNathan Gauër continue; 5901ed65febSNathan Gauër 5911ed65febSNathan Gauër IRBuilder<> Builder(&BB); 5921ed65febSNathan Gauër 5931ed65febSNathan Gauër auto *Merge = getExitFor(CR); 5941ed65febSNathan Gauër // We are indeed in a loop, but there are no exits (infinite loop). 5951ed65febSNathan Gauër // This could be caused by a bad shader, but also could be an artifact 5961ed65febSNathan Gauër // from an earlier optimization. It is not always clear if structurally 5971ed65febSNathan Gauër // reachable means runtime reachable, so we cannot error-out. What we must 5981ed65febSNathan Gauër // do however is to make is legal on the SPIR-V point of view, hence 5991ed65febSNathan Gauër // adding an unreachable merge block. 6001ed65febSNathan Gauër if (Merge == nullptr) { 6011ed65febSNathan Gauër BranchInst *Br = cast<BranchInst>(BB.getTerminator()); 6021ed65febSNathan Gauër assert(Br && 6031ed65febSNathan Gauër "This assumes the branch is not a switch. Maybe that's wrong?"); 6041ed65febSNathan Gauër assert(cast<BranchInst>(BB.getTerminator())->isUnconditional()); 6051ed65febSNathan Gauër 6061ed65febSNathan Gauër Merge = CreateUnreachable(F); 6071ed65febSNathan Gauër Builder.SetInsertPoint(Br); 6081ed65febSNathan Gauër Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0)); 6091ed65febSNathan Gauër Br->eraseFromParent(); 6101ed65febSNathan Gauër } 6111ed65febSNathan Gauër 6121ed65febSNathan Gauër auto *Continue = L->getLoopLatch(); 6131ed65febSNathan Gauër 6141ed65febSNathan Gauër Builder.SetInsertPoint(BB.getTerminator()); 6151ed65febSNathan Gauër auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); 6161ed65febSNathan Gauër auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue); 6171ed65febSNathan Gauër SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress}; 6181ed65febSNathan Gauër 6191ed65febSNathan Gauër Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {}, {Args}); 6201ed65febSNathan Gauër Modified = true; 6211ed65febSNathan Gauër } 6221ed65febSNathan Gauër 6231ed65febSNathan Gauër return Modified; 6241ed65febSNathan Gauër } 6251ed65febSNathan Gauër 6261ed65febSNathan Gauër // Adds an OpSelectionMerge to the immediate dominator or each node with an 6271ed65febSNathan Gauër // in-degree of 2 or more which is not already the merge target of an 6281ed65febSNathan Gauër // OpLoopMerge/OpSelectionMerge. 6291ed65febSNathan Gauër bool addMergeForNodesWithMultiplePredecessors(Function &F) { 6301ed65febSNathan Gauër DomTreeBuilder::BBDomTree DT; 6311ed65febSNathan Gauër DT.recalculate(F); 6321ed65febSNathan Gauër 6331ed65febSNathan Gauër bool Modified = false; 6341ed65febSNathan Gauër for (auto &BB : F) { 6351ed65febSNathan Gauër if (pred_size(&BB) <= 1) 6361ed65febSNathan Gauër continue; 6371ed65febSNathan Gauër 6381ed65febSNathan Gauër if (hasLoopMergeInstruction(BB) && pred_size(&BB) <= 2) 6391ed65febSNathan Gauër continue; 6401ed65febSNathan Gauër 6411ed65febSNathan Gauër assert(DT.getNode(&BB)->getIDom()); 6421ed65febSNathan Gauër BasicBlock *Header = DT.getNode(&BB)->getIDom()->getBlock(); 6431ed65febSNathan Gauër 6441ed65febSNathan Gauër if (isDefinedAsSelectionMergeBy(*Header, BB)) 6451ed65febSNathan Gauër continue; 6461ed65febSNathan Gauër 6471ed65febSNathan Gauër IRBuilder<> Builder(Header); 6481ed65febSNathan Gauër Builder.SetInsertPoint(Header->getTerminator()); 6491ed65febSNathan Gauër 6501ed65febSNathan Gauër auto MergeAddress = BlockAddress::get(BB.getParent(), &BB); 651380bb51bSjoaosaffran createOpSelectMerge(&Builder, MergeAddress); 6521ed65febSNathan Gauër 6531ed65febSNathan Gauër Modified = true; 6541ed65febSNathan Gauër } 6551ed65febSNathan Gauër 6561ed65febSNathan Gauër return Modified; 6571ed65febSNathan Gauër } 6581ed65febSNathan Gauër 6591ed65febSNathan Gauër // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts 6601ed65febSNathan Gauër // them to put the "largest" first. A merge instruction is defined as larger 6611ed65febSNathan Gauër // than another when its target merge block post-dominates the other target's 6621ed65febSNathan Gauër // merge block. (This ordering should match the nesting ordering of the source 6631ed65febSNathan Gauër // HLSL). 6641ed65febSNathan Gauër bool sortSelectionMerge(Function &F, BasicBlock &Block) { 6651ed65febSNathan Gauër std::vector<Instruction *> MergeInstructions; 6661ed65febSNathan Gauër for (Instruction &I : Block) 6671ed65febSNathan Gauër if (isMergeInstruction(&I)) 6681ed65febSNathan Gauër MergeInstructions.push_back(&I); 6691ed65febSNathan Gauër 6701ed65febSNathan Gauër if (MergeInstructions.size() <= 1) 6711ed65febSNathan Gauër return false; 6721ed65febSNathan Gauër 6731ed65febSNathan Gauër Instruction *InsertionPoint = *MergeInstructions.begin(); 6741ed65febSNathan Gauër 6751ed65febSNathan Gauër PartialOrderingVisitor Visitor(F); 6761ed65febSNathan Gauër std::sort(MergeInstructions.begin(), MergeInstructions.end(), 6771ed65febSNathan Gauër [&Visitor](Instruction *Left, Instruction *Right) { 6781ed65febSNathan Gauër if (Left == Right) 6791ed65febSNathan Gauër return false; 6801ed65febSNathan Gauër BasicBlock *RightMerge = getDesignatedMergeBlock(Right); 6811ed65febSNathan Gauër BasicBlock *LeftMerge = getDesignatedMergeBlock(Left); 6821ed65febSNathan Gauër return !Visitor.compare(RightMerge, LeftMerge); 6831ed65febSNathan Gauër }); 6841ed65febSNathan Gauër 6851ed65febSNathan Gauër for (Instruction *I : MergeInstructions) { 686*304a9909SJeremy Morse I->moveBefore(InsertionPoint->getIterator()); 6871ed65febSNathan Gauër InsertionPoint = I; 6881ed65febSNathan Gauër } 6891ed65febSNathan Gauër 6901ed65febSNathan Gauër return true; 6911ed65febSNathan Gauër } 6921ed65febSNathan Gauër 6931ed65febSNathan Gauër // Sorts selection merge headers in |F|. 6941ed65febSNathan Gauër // A is sorted before B if the merge block designated by B is an ancestor of 6951ed65febSNathan Gauër // the one designated by A. 6961ed65febSNathan Gauër bool sortSelectionMergeHeaders(Function &F) { 6971ed65febSNathan Gauër bool Modified = false; 6981ed65febSNathan Gauër for (BasicBlock &BB : F) { 6991ed65febSNathan Gauër Modified |= sortSelectionMerge(F, BB); 7001ed65febSNathan Gauër } 7011ed65febSNathan Gauër return Modified; 7021ed65febSNathan Gauër } 7031ed65febSNathan Gauër 7041ed65febSNathan Gauër // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge 7051ed65febSNathan Gauër // instructions so each basic block contains only a single merge instruction. 7061ed65febSNathan Gauër bool splitBlocksWithMultipleHeaders(Function &F) { 7071ed65febSNathan Gauër std::stack<BasicBlock *> Work; 7081ed65febSNathan Gauër for (auto &BB : F) { 7091ed65febSNathan Gauër std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB); 7101ed65febSNathan Gauër if (MergeInstructions.size() <= 1) 7111ed65febSNathan Gauër continue; 7121ed65febSNathan Gauër Work.push(&BB); 7131ed65febSNathan Gauër } 7141ed65febSNathan Gauër 7151ed65febSNathan Gauër const bool Modified = Work.size() > 0; 7161ed65febSNathan Gauër while (Work.size() > 0) { 7171ed65febSNathan Gauër BasicBlock *Header = Work.top(); 7181ed65febSNathan Gauër Work.pop(); 7191ed65febSNathan Gauër 7201ed65febSNathan Gauër std::vector<Instruction *> MergeInstructions = 7211ed65febSNathan Gauër getMergeInstructions(*Header); 7221ed65febSNathan Gauër for (unsigned i = 1; i < MergeInstructions.size(); i++) { 7231ed65febSNathan Gauër BasicBlock *NewBlock = 7241ed65febSNathan Gauër Header->splitBasicBlock(MergeInstructions[i], "new.header"); 7251ed65febSNathan Gauër 7261ed65febSNathan Gauër if (getDesignatedContinueBlock(MergeInstructions[0]) == nullptr) { 7271ed65febSNathan Gauër BasicBlock *Unreachable = CreateUnreachable(F); 7281ed65febSNathan Gauër 7291ed65febSNathan Gauër BranchInst *BI = cast<BranchInst>(Header->getTerminator()); 7301ed65febSNathan Gauër IRBuilder<> Builder(Header); 7311ed65febSNathan Gauër Builder.SetInsertPoint(BI); 7321ed65febSNathan Gauër Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable); 7331ed65febSNathan Gauër BI->eraseFromParent(); 7341ed65febSNathan Gauër } 7351ed65febSNathan Gauër 7361ed65febSNathan Gauër Header = NewBlock; 7371ed65febSNathan Gauër } 7381ed65febSNathan Gauër } 7391ed65febSNathan Gauër 7401ed65febSNathan Gauër return Modified; 7411ed65febSNathan Gauër } 7421ed65febSNathan Gauër 7431ed65febSNathan Gauër // Adds an OpSelectionMerge to each block with an out-degree >= 2 which 7441ed65febSNathan Gauër // doesn't already have an OpSelectionMerge. 7451ed65febSNathan Gauër bool addMergeForDivergentBlocks(Function &F) { 7461ed65febSNathan Gauër DomTreeBuilder::BBPostDomTree PDT; 7471ed65febSNathan Gauër PDT.recalculate(F); 7481ed65febSNathan Gauër bool Modified = false; 7491ed65febSNathan Gauër 7501ed65febSNathan Gauër auto MergeBlocks = getMergeBlocks(F); 7511ed65febSNathan Gauër auto ContinueBlocks = getContinueBlocks(F); 7521ed65febSNathan Gauër 7531ed65febSNathan Gauër for (auto &BB : F) { 7541ed65febSNathan Gauër if (getMergeInstructions(BB).size() != 0) 7551ed65febSNathan Gauër continue; 7561ed65febSNathan Gauër 7571ed65febSNathan Gauër std::vector<BasicBlock *> Candidates; 7581ed65febSNathan Gauër for (BasicBlock *Successor : successors(&BB)) { 7591ed65febSNathan Gauër if (MergeBlocks.contains(Successor)) 7601ed65febSNathan Gauër continue; 7611ed65febSNathan Gauër if (ContinueBlocks.contains(Successor)) 7621ed65febSNathan Gauër continue; 7631ed65febSNathan Gauër Candidates.push_back(Successor); 7641ed65febSNathan Gauër } 7651ed65febSNathan Gauër 7661ed65febSNathan Gauër if (Candidates.size() <= 1) 7671ed65febSNathan Gauër continue; 7681ed65febSNathan Gauër 7691ed65febSNathan Gauër Modified = true; 7701ed65febSNathan Gauër BasicBlock *Merge = Candidates[0]; 7711ed65febSNathan Gauër 7721ed65febSNathan Gauër auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); 7731ed65febSNathan Gauër IRBuilder<> Builder(&BB); 7741ed65febSNathan Gauër Builder.SetInsertPoint(BB.getTerminator()); 775380bb51bSjoaosaffran createOpSelectMerge(&Builder, MergeAddress); 7761ed65febSNathan Gauër } 7771ed65febSNathan Gauër 7781ed65febSNathan Gauër return Modified; 7791ed65febSNathan Gauër } 7801ed65febSNathan Gauër 7811ed65febSNathan Gauër // Gather all the exit nodes for the construct header by |Header| and 7821ed65febSNathan Gauër // containing the blocks |Construct|. 7831ed65febSNathan Gauër std::vector<Edge> getExitsFrom(const BlockSet &Construct, 7841ed65febSNathan Gauër BasicBlock &Header) { 7851ed65febSNathan Gauër std::vector<Edge> Output; 7861ed65febSNathan Gauër visit(Header, [&](BasicBlock *Item) { 7871ed65febSNathan Gauër if (Construct.count(Item) == 0) 7881ed65febSNathan Gauër return false; 7891ed65febSNathan Gauër 7901ed65febSNathan Gauër for (BasicBlock *Successor : successors(Item)) { 7911ed65febSNathan Gauër if (Construct.count(Successor) == 0) 7921ed65febSNathan Gauër Output.emplace_back(Item, Successor); 7931ed65febSNathan Gauër } 7941ed65febSNathan Gauër return true; 7951ed65febSNathan Gauër }); 7961ed65febSNathan Gauër 7971ed65febSNathan Gauër return Output; 7981ed65febSNathan Gauër } 7991ed65febSNathan Gauër 8001ed65febSNathan Gauër // Build a divergent construct tree searching from |BB|. 8011ed65febSNathan Gauër // If |Parent| is not null, this tree is attached to the parent's tree. 8021ed65febSNathan Gauër void constructDivergentConstruct(BlockSet &Visited, Splitter &S, 8031ed65febSNathan Gauër BasicBlock *BB, DivergentConstruct *Parent) { 8041ed65febSNathan Gauër if (Visited.count(BB) != 0) 8051ed65febSNathan Gauër return; 8061ed65febSNathan Gauër Visited.insert(BB); 8071ed65febSNathan Gauër 8081ed65febSNathan Gauër auto MIS = getMergeInstructions(*BB); 8091ed65febSNathan Gauër if (MIS.size() == 0) { 8101ed65febSNathan Gauër for (BasicBlock *Successor : successors(BB)) 8111ed65febSNathan Gauër constructDivergentConstruct(Visited, S, Successor, Parent); 8121ed65febSNathan Gauër return; 8131ed65febSNathan Gauër } 8141ed65febSNathan Gauër 8151ed65febSNathan Gauër assert(MIS.size() == 1); 8161ed65febSNathan Gauër Instruction *MI = MIS[0]; 8171ed65febSNathan Gauër 8181ed65febSNathan Gauër BasicBlock *Merge = getDesignatedMergeBlock(MI); 8191ed65febSNathan Gauër BasicBlock *Continue = getDesignatedContinueBlock(MI); 8201ed65febSNathan Gauër 8211ed65febSNathan Gauër auto Output = std::make_unique<DivergentConstruct>(); 8221ed65febSNathan Gauër Output->Header = BB; 8231ed65febSNathan Gauër Output->Merge = Merge; 8241ed65febSNathan Gauër Output->Continue = Continue; 8251ed65febSNathan Gauër Output->Parent = Parent; 8261ed65febSNathan Gauër 8271ed65febSNathan Gauër constructDivergentConstruct(Visited, S, Merge, Parent); 8281ed65febSNathan Gauër if (Continue) 8291ed65febSNathan Gauër constructDivergentConstruct(Visited, S, Continue, Output.get()); 8301ed65febSNathan Gauër 8311ed65febSNathan Gauër for (BasicBlock *Successor : successors(BB)) 8321ed65febSNathan Gauër constructDivergentConstruct(Visited, S, Successor, Output.get()); 8331ed65febSNathan Gauër 8341ed65febSNathan Gauër if (Parent) 8351ed65febSNathan Gauër Parent->Children.emplace_back(std::move(Output)); 8361ed65febSNathan Gauër } 8371ed65febSNathan Gauër 8381ed65febSNathan Gauër // Returns the blocks belonging to the divergent construct |Node|. 8391ed65febSNathan Gauër BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) { 8401ed65febSNathan Gauër assert(Node->Header && Node->Merge); 8411ed65febSNathan Gauër 8421ed65febSNathan Gauër if (Node->Continue) { 8431ed65febSNathan Gauër auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge); 8441ed65febSNathan Gauër return BlockSet(LoopBlocks.begin(), LoopBlocks.end()); 8451ed65febSNathan Gauër } 8461ed65febSNathan Gauër 8471ed65febSNathan Gauër auto SelectionBlocks = S.getSelectionConstructBlocks(Node); 8481ed65febSNathan Gauër return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end()); 8491ed65febSNathan Gauër } 8501ed65febSNathan Gauër 8511ed65febSNathan Gauër // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V 8521ed65febSNathan Gauër // spec. 8531ed65febSNathan Gauër bool fixupConstruct(Splitter &S, DivergentConstruct *Node) { 8541ed65febSNathan Gauër bool Modified = false; 8551ed65febSNathan Gauër for (auto &Child : Node->Children) 8561ed65febSNathan Gauër Modified |= fixupConstruct(S, Child.get()); 8571ed65febSNathan Gauër 8581ed65febSNathan Gauër // This construct is the root construct. Does not represent any real 8591ed65febSNathan Gauër // construct, just a way to access the first level of the forest. 8601ed65febSNathan Gauër if (Node->Parent == nullptr) 8611ed65febSNathan Gauër return Modified; 8621ed65febSNathan Gauër 8631ed65febSNathan Gauër // This node's parent is the root. Meaning this is a top-level construct. 8641ed65febSNathan Gauër // There can be multiple exists, but all are guaranteed to exit at most 1 8651ed65febSNathan Gauër // construct since we are at first level. 8661ed65febSNathan Gauër if (Node->Parent->Header == nullptr) 8671ed65febSNathan Gauër return Modified; 8681ed65febSNathan Gauër 8691ed65febSNathan Gauër // Health check for the structure. 8701ed65febSNathan Gauër assert(Node->Header && Node->Merge); 8711ed65febSNathan Gauër assert(Node->Parent->Header && Node->Parent->Merge); 8721ed65febSNathan Gauër 8731ed65febSNathan Gauër BlockSet ConstructBlocks = getConstructBlocks(S, Node); 8741ed65febSNathan Gauër auto Edges = getExitsFrom(ConstructBlocks, *Node->Header); 8751ed65febSNathan Gauër 8761ed65febSNathan Gauër // No edges exiting the construct. 8771ed65febSNathan Gauër if (Edges.size() < 1) 8781ed65febSNathan Gauër return Modified; 8791ed65febSNathan Gauër 8801ed65febSNathan Gauër bool HasBadEdge = Node->Merge == Node->Parent->Merge || 8811ed65febSNathan Gauër Node->Merge == Node->Parent->Continue; 8821ed65febSNathan Gauër // BasicBlock *Target = Edges[0].second; 8831ed65febSNathan Gauër for (auto &[Src, Dst] : Edges) { 8841ed65febSNathan Gauër // - Breaking from a selection construct: S is a selection construct, S is 8851ed65febSNathan Gauër // the innermost structured 8861ed65febSNathan Gauër // control-flow construct containing A, and B is the merge block for S 8871ed65febSNathan Gauër // - Breaking from the innermost loop: S is the innermost loop construct 8881ed65febSNathan Gauër // containing A, 8891ed65febSNathan Gauër // and B is the merge block for S 8901ed65febSNathan Gauër if (Node->Merge == Dst) 8911ed65febSNathan Gauër continue; 8921ed65febSNathan Gauër 8931ed65febSNathan Gauër // Entering the innermost loop’s continue construct: S is the innermost 8941ed65febSNathan Gauër // loop construct containing A, and B is the continue target for S 8951ed65febSNathan Gauër if (Node->Continue == Dst) 8961ed65febSNathan Gauër continue; 8971ed65febSNathan Gauër 8981ed65febSNathan Gauër // TODO: what about cases branching to another case in the switch? Seems 8991ed65febSNathan Gauër // to work, but need to double check. 9001ed65febSNathan Gauër HasBadEdge = true; 9011ed65febSNathan Gauër } 9021ed65febSNathan Gauër 9031ed65febSNathan Gauër if (!HasBadEdge) 9041ed65febSNathan Gauër return Modified; 9051ed65febSNathan Gauër 9061ed65febSNathan Gauër // Create a single exit node gathering all exit edges. 9071ed65febSNathan Gauër BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges); 9081ed65febSNathan Gauër 9091ed65febSNathan Gauër // Fixup this construct's merge node to point to the new exit. 9101ed65febSNathan Gauër // Note: this algorithm fixes inner-most divergence construct first. So 9111ed65febSNathan Gauër // recursive structures sharing a single merge node are fixed from the 9121ed65febSNathan Gauër // inside toward the outside. 9131ed65febSNathan Gauër auto MergeInstructions = getMergeInstructions(*Node->Header); 9141ed65febSNathan Gauër assert(MergeInstructions.size() == 1); 9151ed65febSNathan Gauër Instruction *I = MergeInstructions[0]; 9161ed65febSNathan Gauër BlockAddress *BA = cast<BlockAddress>(I->getOperand(0)); 9171ed65febSNathan Gauër if (BA->getBasicBlock() == Node->Merge) { 9181ed65febSNathan Gauër auto MergeAddress = BlockAddress::get(NewExit->getParent(), NewExit); 9191ed65febSNathan Gauër I->setOperand(0, MergeAddress); 9201ed65febSNathan Gauër } 9211ed65febSNathan Gauër 9221ed65febSNathan Gauër // Clean up of the possible dangling BockAddr operands to prevent MIR 9231ed65febSNathan Gauër // comments about "address of removed block taken". 9241ed65febSNathan Gauër if (!BA->isConstantUsed()) 9251ed65febSNathan Gauër BA->destroyConstant(); 9261ed65febSNathan Gauër 9271ed65febSNathan Gauër Node->Merge = NewExit; 9281ed65febSNathan Gauër // Regenerate the dom trees. 9291ed65febSNathan Gauër S.invalidate(); 9301ed65febSNathan Gauër return true; 9311ed65febSNathan Gauër } 9321ed65febSNathan Gauër 9331ed65febSNathan Gauër bool splitCriticalEdges(Function &F) { 9341ed65febSNathan Gauër LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 9351ed65febSNathan Gauër Splitter S(F, LI); 9361ed65febSNathan Gauër 9371ed65febSNathan Gauër DivergentConstruct Root; 9381ed65febSNathan Gauër BlockSet Visited; 9391ed65febSNathan Gauër constructDivergentConstruct(Visited, S, &*F.begin(), &Root); 9401ed65febSNathan Gauër return fixupConstruct(S, &Root); 9411ed65febSNathan Gauër } 9421ed65febSNathan Gauër 9431ed65febSNathan Gauër // Simplify branches when possible: 9441ed65febSNathan Gauër // - if the 2 sides of a conditional branch are the same, transforms it to an 9451ed65febSNathan Gauër // unconditional branch. 9461ed65febSNathan Gauër // - if a switch has only 2 distinct successors, converts it to a conditional 9471ed65febSNathan Gauër // branch. 9481ed65febSNathan Gauër bool simplifyBranches(Function &F) { 9491ed65febSNathan Gauër bool Modified = false; 9501ed65febSNathan Gauër 9511ed65febSNathan Gauër for (BasicBlock &BB : F) { 9521ed65febSNathan Gauër SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator()); 9531ed65febSNathan Gauër if (!SI) 9541ed65febSNathan Gauër continue; 9551ed65febSNathan Gauër if (SI->getNumCases() > 1) 9561ed65febSNathan Gauër continue; 9571ed65febSNathan Gauër 9581ed65febSNathan Gauër Modified = true; 9591ed65febSNathan Gauër IRBuilder<> Builder(&BB); 9601ed65febSNathan Gauër Builder.SetInsertPoint(SI); 9611ed65febSNathan Gauër 9621ed65febSNathan Gauër if (SI->getNumCases() == 0) { 9631ed65febSNathan Gauër Builder.CreateBr(SI->getDefaultDest()); 9641ed65febSNathan Gauër } else { 9651ed65febSNathan Gauër Value *Condition = 9661ed65febSNathan Gauër Builder.CreateCmp(CmpInst::ICMP_EQ, SI->getCondition(), 9671ed65febSNathan Gauër SI->case_begin()->getCaseValue()); 9681ed65febSNathan Gauër Builder.CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(), 9691ed65febSNathan Gauër SI->getDefaultDest()); 9701ed65febSNathan Gauër } 9711ed65febSNathan Gauër SI->eraseFromParent(); 9721ed65febSNathan Gauër } 9731ed65febSNathan Gauër 9741ed65febSNathan Gauër return Modified; 9751ed65febSNathan Gauër } 9761ed65febSNathan Gauër 9771ed65febSNathan Gauër // Makes sure every case target in |F| is unique. If 2 cases branch to the 9781ed65febSNathan Gauër // same basic block, one of the targets is updated so it jumps to a new basic 9791ed65febSNathan Gauër // block ending with a single unconditional branch to the original target. 9801ed65febSNathan Gauër bool splitSwitchCases(Function &F) { 9811ed65febSNathan Gauër bool Modified = false; 9821ed65febSNathan Gauër 9831ed65febSNathan Gauër for (BasicBlock &BB : F) { 9841ed65febSNathan Gauër SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator()); 9851ed65febSNathan Gauër if (!SI) 9861ed65febSNathan Gauër continue; 9871ed65febSNathan Gauër 9881ed65febSNathan Gauër BlockSet Seen; 9891ed65febSNathan Gauër Seen.insert(SI->getDefaultDest()); 9901ed65febSNathan Gauër 9911ed65febSNathan Gauër auto It = SI->case_begin(); 9921ed65febSNathan Gauër while (It != SI->case_end()) { 9931ed65febSNathan Gauër BasicBlock *Target = It->getCaseSuccessor(); 9941ed65febSNathan Gauër if (Seen.count(Target) == 0) { 9951ed65febSNathan Gauër Seen.insert(Target); 9961ed65febSNathan Gauër ++It; 9971ed65febSNathan Gauër continue; 9981ed65febSNathan Gauër } 9991ed65febSNathan Gauër 10001ed65febSNathan Gauër Modified = true; 10011ed65febSNathan Gauër BasicBlock *NewTarget = 10021ed65febSNathan Gauër BasicBlock::Create(F.getContext(), "new.sw.case", &F); 10031ed65febSNathan Gauër IRBuilder<> Builder(NewTarget); 10041ed65febSNathan Gauër Builder.CreateBr(Target); 10051ed65febSNathan Gauër SI->addCase(It->getCaseValue(), NewTarget); 10061ed65febSNathan Gauër It = SI->removeCase(It); 10071ed65febSNathan Gauër } 10081ed65febSNathan Gauër } 10091ed65febSNathan Gauër 10101ed65febSNathan Gauër return Modified; 10111ed65febSNathan Gauër } 10121ed65febSNathan Gauër 1013cba70550SNathan Gauër // Removes blocks not contributing to any structured CFG. This assumes there 1014cba70550SNathan Gauër // is no PHI nodes. 10151ed65febSNathan Gauër bool removeUselessBlocks(Function &F) { 10161ed65febSNathan Gauër std::vector<BasicBlock *> ToRemove; 10171ed65febSNathan Gauër 10181ed65febSNathan Gauër auto MergeBlocks = getMergeBlocks(F); 10191ed65febSNathan Gauër auto ContinueBlocks = getContinueBlocks(F); 10201ed65febSNathan Gauër 10211ed65febSNathan Gauër for (BasicBlock &BB : F) { 10221ed65febSNathan Gauër if (BB.size() != 1) 10231ed65febSNathan Gauër continue; 10241ed65febSNathan Gauër 10251ed65febSNathan Gauër if (isa<ReturnInst>(BB.getTerminator())) 10261ed65febSNathan Gauër continue; 10271ed65febSNathan Gauër 10281ed65febSNathan Gauër if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0) 10291ed65febSNathan Gauër continue; 10301ed65febSNathan Gauër 10311ed65febSNathan Gauër if (BB.getUniqueSuccessor() == nullptr) 10321ed65febSNathan Gauër continue; 10331ed65febSNathan Gauër 10341ed65febSNathan Gauër BasicBlock *Successor = BB.getUniqueSuccessor(); 10351ed65febSNathan Gauër std::vector<BasicBlock *> Predecessors(predecessors(&BB).begin(), 10361ed65febSNathan Gauër predecessors(&BB).end()); 10371ed65febSNathan Gauër for (BasicBlock *Predecessor : Predecessors) 10381ed65febSNathan Gauër replaceBranchTargets(Predecessor, &BB, Successor); 10391ed65febSNathan Gauër ToRemove.push_back(&BB); 10401ed65febSNathan Gauër } 10411ed65febSNathan Gauër 10421ed65febSNathan Gauër for (BasicBlock *BB : ToRemove) 10431ed65febSNathan Gauër BB->eraseFromParent(); 10441ed65febSNathan Gauër 10451ed65febSNathan Gauër return ToRemove.size() != 0; 10461ed65febSNathan Gauër } 10471ed65febSNathan Gauër 10481ed65febSNathan Gauër bool addHeaderToRemainingDivergentDAG(Function &F) { 10491ed65febSNathan Gauër bool Modified = false; 10501ed65febSNathan Gauër 10511ed65febSNathan Gauër auto MergeBlocks = getMergeBlocks(F); 10521ed65febSNathan Gauër auto ContinueBlocks = getContinueBlocks(F); 10531ed65febSNathan Gauër auto HeaderBlocks = getHeaderBlocks(F); 10541ed65febSNathan Gauër 10551ed65febSNathan Gauër DomTreeBuilder::BBDomTree DT; 10561ed65febSNathan Gauër DomTreeBuilder::BBPostDomTree PDT; 10571ed65febSNathan Gauër PDT.recalculate(F); 10581ed65febSNathan Gauër DT.recalculate(F); 10591ed65febSNathan Gauër 10601ed65febSNathan Gauër for (BasicBlock &BB : F) { 10611ed65febSNathan Gauër if (HeaderBlocks.count(&BB) != 0) 10621ed65febSNathan Gauër continue; 10631ed65febSNathan Gauër if (succ_size(&BB) < 2) 10641ed65febSNathan Gauër continue; 10651ed65febSNathan Gauër 10661ed65febSNathan Gauër size_t CandidateEdges = 0; 10671ed65febSNathan Gauër for (BasicBlock *Successor : successors(&BB)) { 10681ed65febSNathan Gauër if (MergeBlocks.count(Successor) != 0 || 10691ed65febSNathan Gauër ContinueBlocks.count(Successor) != 0) 10701ed65febSNathan Gauër continue; 10711ed65febSNathan Gauër if (HeaderBlocks.count(Successor) != 0) 10721ed65febSNathan Gauër continue; 10731ed65febSNathan Gauër CandidateEdges += 1; 10741ed65febSNathan Gauër } 10751ed65febSNathan Gauër 10761ed65febSNathan Gauër if (CandidateEdges <= 1) 10771ed65febSNathan Gauër continue; 10781ed65febSNathan Gauër 10791ed65febSNathan Gauër BasicBlock *Header = &BB; 10801ed65febSNathan Gauër BasicBlock *Merge = PDT.getNode(&BB)->getIDom()->getBlock(); 10811ed65febSNathan Gauër 10821ed65febSNathan Gauër bool HasBadBlock = false; 10831ed65febSNathan Gauër visit(*Header, [&](const BasicBlock *Node) { 10841ed65febSNathan Gauër if (DT.dominates(Header, Node)) 10851ed65febSNathan Gauër return false; 10861ed65febSNathan Gauër if (PDT.dominates(Merge, Node)) 10871ed65febSNathan Gauër return false; 10881ed65febSNathan Gauër if (Node == Header || Node == Merge) 10891ed65febSNathan Gauër return true; 10901ed65febSNathan Gauër 10911ed65febSNathan Gauër HasBadBlock |= MergeBlocks.count(Node) != 0 || 10921ed65febSNathan Gauër ContinueBlocks.count(Node) != 0 || 10931ed65febSNathan Gauër HeaderBlocks.count(Node) != 0; 10941ed65febSNathan Gauër return !HasBadBlock; 10951ed65febSNathan Gauër }); 10961ed65febSNathan Gauër 10971ed65febSNathan Gauër if (HasBadBlock) 10981ed65febSNathan Gauër continue; 10991ed65febSNathan Gauër 11001ed65febSNathan Gauër Modified = true; 1101cba70550SNathan Gauër 1102cba70550SNathan Gauër if (Merge == nullptr) { 1103cba70550SNathan Gauër Merge = *successors(Header).begin(); 1104cba70550SNathan Gauër IRBuilder<> Builder(Header); 1105cba70550SNathan Gauër Builder.SetInsertPoint(Header->getTerminator()); 1106cba70550SNathan Gauër 1107cba70550SNathan Gauër auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); 1108380bb51bSjoaosaffran createOpSelectMerge(&Builder, MergeAddress); 1109cba70550SNathan Gauër continue; 1110cba70550SNathan Gauër } 1111cba70550SNathan Gauër 11121ed65febSNathan Gauër Instruction *SplitInstruction = Merge->getTerminator(); 11131ed65febSNathan Gauër if (isMergeInstruction(SplitInstruction->getPrevNode())) 11141ed65febSNathan Gauër SplitInstruction = SplitInstruction->getPrevNode(); 11151ed65febSNathan Gauër BasicBlock *NewMerge = 11161ed65febSNathan Gauër Merge->splitBasicBlockBefore(SplitInstruction, "new.merge"); 11171ed65febSNathan Gauër 11181ed65febSNathan Gauër IRBuilder<> Builder(Header); 11191ed65febSNathan Gauër Builder.SetInsertPoint(Header->getTerminator()); 11201ed65febSNathan Gauër 11211ed65febSNathan Gauër auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge); 1122380bb51bSjoaosaffran createOpSelectMerge(&Builder, MergeAddress); 11231ed65febSNathan Gauër } 11241ed65febSNathan Gauër 11251ed65febSNathan Gauër return Modified; 11261ed65febSNathan Gauër } 11271ed65febSNathan Gauër 11281ed65febSNathan Gauër public: 11291ed65febSNathan Gauër static char ID; 11301ed65febSNathan Gauër 11311ed65febSNathan Gauër SPIRVStructurizer() : FunctionPass(ID) { 11321ed65febSNathan Gauër initializeSPIRVStructurizerPass(*PassRegistry::getPassRegistry()); 11331ed65febSNathan Gauër }; 11341ed65febSNathan Gauër 11351ed65febSNathan Gauër virtual bool runOnFunction(Function &F) override { 11361ed65febSNathan Gauër bool Modified = false; 11371ed65febSNathan Gauër 11381ed65febSNathan Gauër // In LLVM, Switches are allowed to have several cases branching to the same 11391ed65febSNathan Gauër // basic block. This is allowed in SPIR-V, but can make structurizing SPIR-V 11401ed65febSNathan Gauër // harder, so first remove edge cases. 11411ed65febSNathan Gauër Modified |= splitSwitchCases(F); 11421ed65febSNathan Gauër 11431ed65febSNathan Gauër // LLVM allows conditional branches to have both side jumping to the same 11441ed65febSNathan Gauër // block. It also allows switched to have a single default, or just one 11451ed65febSNathan Gauër // case. Cleaning this up now. 11461ed65febSNathan Gauër Modified |= simplifyBranches(F); 11471ed65febSNathan Gauër 11481ed65febSNathan Gauër // At this state, we should have a reducible CFG with cycles. 11491ed65febSNathan Gauër // STEP 1: Adding OpLoopMerge instructions to loop headers. 11501ed65febSNathan Gauër Modified |= addMergeForLoops(F); 11511ed65febSNathan Gauër 11521ed65febSNathan Gauër // STEP 2: adding OpSelectionMerge to each node with an in-degree >= 2. 11531ed65febSNathan Gauër Modified |= addMergeForNodesWithMultiplePredecessors(F); 11541ed65febSNathan Gauër 11551ed65febSNathan Gauër // STEP 3: 11561ed65febSNathan Gauër // Sort selection merge, the largest construct goes first. 11571ed65febSNathan Gauër // This simplifies the next step. 11581ed65febSNathan Gauër Modified |= sortSelectionMergeHeaders(F); 11591ed65febSNathan Gauër 11601ed65febSNathan Gauër // STEP 4: As this stage, we can have a single basic block with multiple 11611ed65febSNathan Gauër // OpLoopMerge/OpSelectionMerge instructions. Splitting this block so each 11621ed65febSNathan Gauër // BB has a single merge instruction. 11631ed65febSNathan Gauër Modified |= splitBlocksWithMultipleHeaders(F); 11641ed65febSNathan Gauër 11651ed65febSNathan Gauër // STEP 5: In the previous steps, we added merge blocks the loops and 11661ed65febSNathan Gauër // natural merge blocks (in-degree >= 2). What remains are conditions with 11671ed65febSNathan Gauër // an exiting branch (return, unreachable). In such case, we must start from 11681ed65febSNathan Gauër // the header, and add headers to divergent construct with no headers. 11691ed65febSNathan Gauër Modified |= addMergeForDivergentBlocks(F); 11701ed65febSNathan Gauër 11711ed65febSNathan Gauër // STEP 6: At this stage, we have several divergent construct defines by a 11721ed65febSNathan Gauër // header and a merge block. But their boundaries have no constraints: a 11731ed65febSNathan Gauër // construct exit could be outside of the parents' construct exit. Such 11741ed65febSNathan Gauër // edges are called critical edges. What we need is to split those edges 11751ed65febSNathan Gauër // into several parts. Each part exiting the parent's construct by its merge 11761ed65febSNathan Gauër // block. 11771ed65febSNathan Gauër Modified |= splitCriticalEdges(F); 11781ed65febSNathan Gauër 11791ed65febSNathan Gauër // STEP 7: The previous steps possibly created a lot of "proxy" blocks. 11801ed65febSNathan Gauër // Blocks with a single unconditional branch, used to create a valid 11811ed65febSNathan Gauër // divergent construct tree. Some nodes are still requires (e.g: nodes 11821ed65febSNathan Gauër // allowing a valid exit through the parent's merge block). But some are 11831ed65febSNathan Gauër // left-overs of past transformations, and could cause actual validation 11841ed65febSNathan Gauër // issues. E.g: the SPIR-V spec allows a construct to break to the parents 11851ed65febSNathan Gauër // loop construct without an OpSelectionMerge, but this requires a straight 11861ed65febSNathan Gauër // jump. If a proxy block lies between the conditional branch and the 11871ed65febSNathan Gauër // parent's merge, the CFG is not valid. 11881ed65febSNathan Gauër Modified |= removeUselessBlocks(F); 11891ed65febSNathan Gauër 11901ed65febSNathan Gauër // STEP 8: Final fix-up steps: our tree boundaries are correct, but some 11911ed65febSNathan Gauër // blocks are branching with no header. Those are often simple conditional 11921ed65febSNathan Gauër // branches with 1 or 2 returning edges. Adding a header for those. 11931ed65febSNathan Gauër Modified |= addHeaderToRemainingDivergentDAG(F); 11941ed65febSNathan Gauër 11951ed65febSNathan Gauër // STEP 9: sort basic blocks to match both the LLVM & SPIR-V requirements. 11961ed65febSNathan Gauër Modified |= sortBlocks(F); 11971ed65febSNathan Gauër 11981ed65febSNathan Gauër return Modified; 11991ed65febSNathan Gauër } 12001ed65febSNathan Gauër 12011ed65febSNathan Gauër void getAnalysisUsage(AnalysisUsage &AU) const override { 12021ed65febSNathan Gauër AU.addRequired<DominatorTreeWrapperPass>(); 12031ed65febSNathan Gauër AU.addRequired<LoopInfoWrapperPass>(); 12041ed65febSNathan Gauër AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); 12051ed65febSNathan Gauër 12061ed65febSNathan Gauër AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>(); 12071ed65febSNathan Gauër FunctionPass::getAnalysisUsage(AU); 12081ed65febSNathan Gauër } 1209380bb51bSjoaosaffran 1210380bb51bSjoaosaffran void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) { 1211380bb51bSjoaosaffran Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator(); 1212380bb51bSjoaosaffran 1213380bb51bSjoaosaffran MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); 1214380bb51bSjoaosaffran 1215380bb51bSjoaosaffran ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0); 1216380bb51bSjoaosaffran 1217380bb51bSjoaosaffran if (MDNode) { 1218380bb51bSjoaosaffran assert(MDNode->getNumOperands() == 2 && 1219380bb51bSjoaosaffran "invalid metadata hlsl.controlflow.hint"); 1220380bb51bSjoaosaffran BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1)); 1221380bb51bSjoaosaffran 1222380bb51bSjoaosaffran assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint"); 1223380bb51bSjoaosaffran } 1224380bb51bSjoaosaffran 1225380bb51bSjoaosaffran llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint}; 1226380bb51bSjoaosaffran 1227380bb51bSjoaosaffran Builder->CreateIntrinsic(Intrinsic::spv_selection_merge, 1228380bb51bSjoaosaffran {MergeAddress->getType()}, {Args}); 1229380bb51bSjoaosaffran } 12301ed65febSNathan Gauër }; 12311ed65febSNathan Gauër } // namespace llvm 12321ed65febSNathan Gauër 12331ed65febSNathan Gauër char SPIRVStructurizer::ID = 0; 12341ed65febSNathan Gauër 123510b1caf6Sjoaosaffran INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "spirv-structurizer", 123610b1caf6Sjoaosaffran "structurize SPIRV", false, false) 12371ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 12381ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 12391ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 12401ed65febSNathan Gauër INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) 12411ed65febSNathan Gauër 124210b1caf6Sjoaosaffran INITIALIZE_PASS_END(SPIRVStructurizer, "spirv-structurizer", 124310b1caf6Sjoaosaffran "structurize SPIRV", false, false) 12441ed65febSNathan Gauër 12451ed65febSNathan Gauër FunctionPass *llvm::createSPIRVStructurizerPass() { 12461ed65febSNathan Gauër return new SPIRVStructurizer(); 12471ed65febSNathan Gauër } 124810b1caf6Sjoaosaffran 124910b1caf6Sjoaosaffran PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F, 125010b1caf6Sjoaosaffran FunctionAnalysisManager &AF) { 1251380bb51bSjoaosaffran 1252380bb51bSjoaosaffran auto FPM = legacy::FunctionPassManager(F.getParent()); 1253380bb51bSjoaosaffran FPM.add(createSPIRVStructurizerPass()); 1254380bb51bSjoaosaffran 1255380bb51bSjoaosaffran if (!FPM.run(F)) 125610b1caf6Sjoaosaffran return PreservedAnalyses::all(); 125710b1caf6Sjoaosaffran PreservedAnalyses PA; 125810b1caf6Sjoaosaffran PA.preserveSet<CFGAnalyses>(); 125910b1caf6Sjoaosaffran return PA; 126010b1caf6Sjoaosaffran } 1261