xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp (revision 3da5e82e31712792411945b655929a1680fb476c)
1 //===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PTX does not have a notion of `unreachable`, which results in emitted basic
10 // blocks having an edge to the next block:
11 //
12 //   block1:
13 //     call @does_not_return();
14 //     // unreachable
15 //   block2:
16 //     // ptxas will create a CFG edge from block1 to block2
17 //
18 // This may result in significant changes to the control flow graph, e.g., when
19 // LLVM moves unreachable blocks to the end of the function. That's a problem
20 // in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
22 // divergently.
23 //
24 // For example, `bar.sync` is not allowed to be executed divergently on Pascal
25 // or earlier. If we start with the following:
26 //
27 //   entry:
28 //     // start of divergent region
29 //     @%p0 bra cont;
30 //     @%p1 bra unlikely;
31 //     ...
32 //     bra.uni cont;
33 //   unlikely:
34 //     ...
35 //     // unreachable
36 //   cont:
37 //     // end of divergent region
38 //     bar.sync 0;
39 //     bra.uni exit;
40 //   exit:
41 //     ret;
42 //
43 // it is transformed by the branch-folder and block-placement passes to:
44 //
45 //   entry:
46 //     // start of divergent region
47 //     @%p0 bra cont;
48 //     @%p1 bra unlikely;
49 //     ...
50 //     bra.uni cont;
51 //   cont:
52 //     bar.sync 0;
53 //     bra.uni exit;
54 //   unlikely:
55 //     ...
56 //     // unreachable
57 //   exit:
58 //     // end of divergent region
59 //     ret;
60 //
61 // After moving the `unlikely` block to the end of the function, it has an edge
62 // to the `exit` block, which widens the divergent region and makes the
63 // `bar.sync` instruction happen divergently.
64 //
65 // To work around this, we add an `exit` instruction before every `unreachable`,
// as `ptxas` understands that exit terminates the CFG. We only do this if
// `unreachable` is not lowered to `trap`, which has the same effect (although
// with current versions of `ptxas` only because it is emitted as `trap; exit;`).
69 //
70 //===----------------------------------------------------------------------===//
71 
72 #include "NVPTX.h"
73 #include "llvm/IR/Function.h"
74 #include "llvm/IR/InlineAsm.h"
75 #include "llvm/IR/Instructions.h"
76 #include "llvm/IR/Type.h"
77 #include "llvm/Pass.h"
78 
79 using namespace llvm;
80 
81 namespace llvm {
82 void initializeNVPTXLowerUnreachablePass(PassRegistry &);
83 }
84 
85 namespace {
86 class NVPTXLowerUnreachable : public FunctionPass {
87   StringRef getPassName() const override;
88   bool runOnFunction(Function &F) override;
89   bool isLoweredToTrap(const UnreachableInst &I) const;
90 
91 public:
92   static char ID; // Pass identification, replacement for typeid
93   NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
94       : FunctionPass(ID), TrapUnreachable(TrapUnreachable),
95         NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}
96 
97 private:
98   bool TrapUnreachable;
99   bool NoTrapAfterNoreturn;
100 };
101 } // namespace
102 
103 char NVPTXLowerUnreachable::ID = 1;
104 
105 INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
106                 "Lower Unreachable", false, false)
107 
108 StringRef NVPTXLowerUnreachable::getPassName() const {
109   return "add an exit instruction before every unreachable";
110 }
111 
112 // =============================================================================
113 // Returns whether a `trap` intrinsic would be emitted before I.
114 //
115 // This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
116 // =============================================================================
117 bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
118   if (const auto *Call = dyn_cast_or_null<CallInst>(I.getPrevNode())) {
119     // We've already emitted a non-continuable trap.
120     if (Call->isNonContinuableTrap())
121       return true;
122 
123     // No traps are emitted for calls that do not return
124     // when this option is enabled.
125     if (NoTrapAfterNoreturn && Call->doesNotReturn())
126       return false;
127   }
128 
129   // In all other cases, we will generate a trap if TrapUnreachable is set.
130   return TrapUnreachable;
131 }
132 
133 // =============================================================================
134 // Main function for this pass.
135 // =============================================================================
136 bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
137   if (skipFunction(F))
138     return false;
139   // Early out iff isLoweredToTrap() always returns true.
140   if (TrapUnreachable && !NoTrapAfterNoreturn)
141     return false;
142 
143   LLVMContext &C = F.getContext();
144   FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
145   InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true);
146 
147   bool Changed = false;
148   for (auto &BB : F)
149     for (auto &I : BB) {
150       if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) {
151         if (isLoweredToTrap(*unreachableInst))
152           continue; // trap is emitted as `trap; exit;`.
153         CallInst::Create(ExitFTy, Exit, "", unreachableInst->getIterator());
154         Changed = true;
155       }
156     }
157   return Changed;
158 }
159 
160 FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
161                                                     bool NoTrapAfterNoreturn) {
162   return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
163 }
164