xref: /llvm-project/llvm/unittests/Target/SPIRV/SPIRVPartialOrderingVisitorTests.cpp (revision 45b567be8d0d430c786c41f826d192fadf863bb8)
1*45b567beSNathan Gauër //===- SPIRVPartialOrderingVisitorTests.cpp ----------------------------===//
2*45b567beSNathan Gauër //
3*45b567beSNathan Gauër // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*45b567beSNathan Gauër // See https://llvm.org/LICENSE.txt for license information.
5*45b567beSNathan Gauër // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*45b567beSNathan Gauër //
7*45b567beSNathan Gauër //===----------------------------------------------------------------------===//
8*45b567beSNathan Gauër 
9*45b567beSNathan Gauër #include "SPIRVUtils.h"
10*45b567beSNathan Gauër #include "llvm/Analysis/DominanceFrontier.h"
11*45b567beSNathan Gauër #include "llvm/Analysis/PostDominators.h"
12*45b567beSNathan Gauër #include "llvm/AsmParser/Parser.h"
13*45b567beSNathan Gauër #include "llvm/IR/Instructions.h"
14*45b567beSNathan Gauër #include "llvm/IR/LLVMContext.h"
15*45b567beSNathan Gauër #include "llvm/IR/LegacyPassManager.h"
16*45b567beSNathan Gauër #include "llvm/IR/Module.h"
17*45b567beSNathan Gauër #include "llvm/IR/PassInstrumentation.h"
18*45b567beSNathan Gauër #include "llvm/IR/Type.h"
19*45b567beSNathan Gauër #include "llvm/IR/TypedPointerType.h"
20*45b567beSNathan Gauër #include "llvm/Support/SourceMgr.h"
21*45b567beSNathan Gauër 
22*45b567beSNathan Gauër #include "gmock/gmock.h"
23*45b567beSNathan Gauër #include "gtest/gtest.h"
24*45b567beSNathan Gauër #include <queue>
25*45b567beSNathan Gauër 
26*45b567beSNathan Gauër using namespace llvm;
27*45b567beSNathan Gauër using namespace llvm::SPIRV;
28*45b567beSNathan Gauër 
29*45b567beSNathan Gauër class SPIRVPartialOrderingVisitorTest : public testing::Test {
30*45b567beSNathan Gauër protected:
31*45b567beSNathan Gauër   void TearDown() override { M.reset(); }
32*45b567beSNathan Gauër 
33*45b567beSNathan Gauër   void run(StringRef Assembly) {
34*45b567beSNathan Gauër     assert(M == nullptr &&
35*45b567beSNathan Gauër            "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
36*45b567beSNathan Gauër 
37*45b567beSNathan Gauër     SMDiagnostic Error;
38*45b567beSNathan Gauër     M = parseAssemblyString(Assembly, Error, Context);
39*45b567beSNathan Gauër     assert(M && "Bad assembly. Bad test?");
40*45b567beSNathan Gauër 
41*45b567beSNathan Gauër     llvm::Function *F = M->getFunction("main");
42*45b567beSNathan Gauër     Visitor = std::make_unique<PartialOrderingVisitor>(*F);
43*45b567beSNathan Gauër   }
44*45b567beSNathan Gauër 
45*45b567beSNathan Gauër   void
46*45b567beSNathan Gauër   checkBasicBlockRank(std::vector<std::pair<const char *, size_t>> &&Expected) {
47*45b567beSNathan Gauër     llvm::Function *F = M->getFunction("main");
48*45b567beSNathan Gauër     auto It = Expected.begin();
49*45b567beSNathan Gauër     Visitor->partialOrderVisit(*F->begin(), [&](BasicBlock *BB) {
50*45b567beSNathan Gauër       const auto &[Name, Rank] = *It;
51*45b567beSNathan Gauër       EXPECT_TRUE(It != Expected.end())
52*45b567beSNathan Gauër           << "Unexpected block \"" << BB->getName() << " visited.";
53*45b567beSNathan Gauër       EXPECT_TRUE(BB->getName() == Name)
54*45b567beSNathan Gauër           << "Error: expected block \"" << Name << "\" got \"" << BB->getName()
55*45b567beSNathan Gauër           << "\"";
56*45b567beSNathan Gauër       EXPECT_EQ(Rank, Visitor->GetNodeRank(BB))
57*45b567beSNathan Gauër           << "Bad rank for BB \"" << BB->getName() << "\"";
58*45b567beSNathan Gauër       It++;
59*45b567beSNathan Gauër       return true;
60*45b567beSNathan Gauër     });
61*45b567beSNathan Gauër     ASSERT_TRUE(It == Expected.end())
62*45b567beSNathan Gauër         << "Expected block \"" << It->first
63*45b567beSNathan Gauër         << "\" but reached the end of the function instead.";
64*45b567beSNathan Gauër   }
65*45b567beSNathan Gauër 
66*45b567beSNathan Gauër protected:
67*45b567beSNathan Gauër   LLVMContext Context;
68*45b567beSNathan Gauër   std::unique_ptr<Module> M;
69*45b567beSNathan Gauër   std::unique_ptr<PartialOrderingVisitor> Visitor;
70*45b567beSNathan Gauër };
71*45b567beSNathan Gauër 
72*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, EmptyFunction) {
73*45b567beSNathan Gauër   StringRef Assembly = R"(
74*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
75*45b567beSNathan Gauër       ret void
76*45b567beSNathan Gauër     }
77*45b567beSNathan Gauër   )";
78*45b567beSNathan Gauër 
79*45b567beSNathan Gauër   run(Assembly);
80*45b567beSNathan Gauër   checkBasicBlockRank({{"", 0}});
81*45b567beSNathan Gauër }
82*45b567beSNathan Gauër 
83*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, BasicBlockSwap) {
84*45b567beSNathan Gauër   StringRef Assembly = R"(
85*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
86*45b567beSNathan Gauër     entry:
87*45b567beSNathan Gauër       br label %middle
88*45b567beSNathan Gauër     exit:
89*45b567beSNathan Gauër       ret void
90*45b567beSNathan Gauër     middle:
91*45b567beSNathan Gauër       br label %exit
92*45b567beSNathan Gauër     }
93*45b567beSNathan Gauër   )";
94*45b567beSNathan Gauër 
95*45b567beSNathan Gauër   run(Assembly);
96*45b567beSNathan Gauër   checkBasicBlockRank({{"entry", 0}, {"middle", 1}, {"exit", 2}});
97*45b567beSNathan Gauër }
98*45b567beSNathan Gauër 
99*45b567beSNathan Gauër // Skip condition:
100*45b567beSNathan Gauër //         +-> A -+
101*45b567beSNathan Gauër //  entry -+      +-> C
102*45b567beSNathan Gauër //         +------+
103*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, SkipCondition) {
104*45b567beSNathan Gauër   StringRef Assembly = R"(
105*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
106*45b567beSNathan Gauër     entry:
107*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
108*45b567beSNathan Gauër       br i1 %1, label %c, label %a
109*45b567beSNathan Gauër     c:
110*45b567beSNathan Gauër       ret void
111*45b567beSNathan Gauër     a:
112*45b567beSNathan Gauër       br label %c
113*45b567beSNathan Gauër     }
114*45b567beSNathan Gauër   )";
115*45b567beSNathan Gauër 
116*45b567beSNathan Gauër   run(Assembly);
117*45b567beSNathan Gauër   checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"c", 2}});
118*45b567beSNathan Gauër }
119*45b567beSNathan Gauër 
120*45b567beSNathan Gauër // Simple loop:
121*45b567beSNathan Gauër // entry -> header <-----------------+
122*45b567beSNathan Gauër //           | `-> body -> continue -+
123*45b567beSNathan Gauër //           `-> end
124*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, LoopOrdering) {
125*45b567beSNathan Gauër   StringRef Assembly = R"(
126*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
127*45b567beSNathan Gauër     entry:
128*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
129*45b567beSNathan Gauër       br label %header
130*45b567beSNathan Gauër     end:
131*45b567beSNathan Gauër       ret void
132*45b567beSNathan Gauër     body:
133*45b567beSNathan Gauër       br label %continue
134*45b567beSNathan Gauër     continue:
135*45b567beSNathan Gauër       br label %header
136*45b567beSNathan Gauër     header:
137*45b567beSNathan Gauër       br i1 %1, label %body, label %end
138*45b567beSNathan Gauër     }
139*45b567beSNathan Gauër   )";
140*45b567beSNathan Gauër 
141*45b567beSNathan Gauër   run(Assembly);
142*45b567beSNathan Gauër   checkBasicBlockRank(
143*45b567beSNathan Gauër       {{"entry", 0}, {"header", 1}, {"body", 2}, {"continue", 3}, {"end", 4}});
144*45b567beSNathan Gauër }
145*45b567beSNathan Gauër 
146*45b567beSNathan Gauër // Diamond condition:
147*45b567beSNathan Gauër //         +-> A -+
148*45b567beSNathan Gauër //  entry -+      +-> C
149*45b567beSNathan Gauër //         +-> B -+
150*45b567beSNathan Gauër //
// A and B order can be flipped with no effect, but it must remain
// deterministic/stable.
153*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, DiamondCondition) {
154*45b567beSNathan Gauër   StringRef Assembly = R"(
155*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
156*45b567beSNathan Gauër     entry:
157*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
158*45b567beSNathan Gauër       br i1 %1, label %a, label %b
159*45b567beSNathan Gauër     c:
160*45b567beSNathan Gauër       ret void
161*45b567beSNathan Gauër     b:
162*45b567beSNathan Gauër       br label %c
163*45b567beSNathan Gauër     a:
164*45b567beSNathan Gauër       br label %c
165*45b567beSNathan Gauër     }
166*45b567beSNathan Gauër   )";
167*45b567beSNathan Gauër 
168*45b567beSNathan Gauër   run(Assembly);
169*45b567beSNathan Gauër   checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}});
170*45b567beSNathan Gauër }
171*45b567beSNathan Gauër 
172*45b567beSNathan Gauër // Crossing conditions:
173*45b567beSNathan Gauër //             +------+  +-> C -+
174*45b567beSNathan Gauër //         +-> A -+   |  |      |
175*45b567beSNathan Gauër //  entry -+      +--_|_-+      +-> E
176*45b567beSNathan Gauër //         +-> B -+   |         |
177*45b567beSNathan Gauër //             +------+----> D -+
178*45b567beSNathan Gauër //
179*45b567beSNathan Gauër // A & B have the same rank.
180*45b567beSNathan Gauër // C & D have the same rank, but are after A & B.
// E is the last block.
182*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, CrossingCondition) {
183*45b567beSNathan Gauër   StringRef Assembly = R"(
184*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
185*45b567beSNathan Gauër     entry:
186*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
187*45b567beSNathan Gauër       br i1 %1, label %a, label %b
188*45b567beSNathan Gauër     e:
189*45b567beSNathan Gauër       ret void
190*45b567beSNathan Gauër     c:
191*45b567beSNathan Gauër       br label %e
192*45b567beSNathan Gauër     b:
193*45b567beSNathan Gauër       br i1 %1, label %d, label %c
194*45b567beSNathan Gauër     d:
195*45b567beSNathan Gauër       br label %e
196*45b567beSNathan Gauër     a:
197*45b567beSNathan Gauër       br i1 %1, label %c, label %d
198*45b567beSNathan Gauër     }
199*45b567beSNathan Gauër   )";
200*45b567beSNathan Gauër 
201*45b567beSNathan Gauër   run(Assembly);
202*45b567beSNathan Gauër   checkBasicBlockRank(
203*45b567beSNathan Gauër       {{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}, {"d", 2}, {"e", 3}});
204*45b567beSNathan Gauër }
205*45b567beSNathan Gauër 
206*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, LoopDiamond) {
207*45b567beSNathan Gauër   StringRef Assembly = R"(
208*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
209*45b567beSNathan Gauër     entry:
210*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
211*45b567beSNathan Gauër       br label %header
212*45b567beSNathan Gauër     header:
213*45b567beSNathan Gauër       br i1 %1, label %body, label %end
214*45b567beSNathan Gauër     body:
215*45b567beSNathan Gauër       br i1 %1, label %inside_a, label %break
216*45b567beSNathan Gauër     inside_a:
217*45b567beSNathan Gauër       br label %inside_b
218*45b567beSNathan Gauër     inside_b:
219*45b567beSNathan Gauër       br i1 %1, label %inside_c, label %inside_d
220*45b567beSNathan Gauër     inside_c:
221*45b567beSNathan Gauër       br label %continue
222*45b567beSNathan Gauër     inside_d:
223*45b567beSNathan Gauër       br label %continue
224*45b567beSNathan Gauër     break:
225*45b567beSNathan Gauër       br label %end
226*45b567beSNathan Gauër     continue:
227*45b567beSNathan Gauër       br label %header
228*45b567beSNathan Gauër     end:
229*45b567beSNathan Gauër       ret void
230*45b567beSNathan Gauër     }
231*45b567beSNathan Gauër   )";
232*45b567beSNathan Gauër 
233*45b567beSNathan Gauër   run(Assembly);
234*45b567beSNathan Gauër   checkBasicBlockRank({{"entry", 0},
235*45b567beSNathan Gauër                        {"header", 1},
236*45b567beSNathan Gauër                        {"body", 2},
237*45b567beSNathan Gauër                        {"inside_a", 3},
238*45b567beSNathan Gauër                        {"inside_b", 4},
239*45b567beSNathan Gauër                        {"inside_c", 5},
240*45b567beSNathan Gauër                        {"inside_d", 5},
241*45b567beSNathan Gauër                        {"continue", 6},
242*45b567beSNathan Gauër                        {"break", 7},
243*45b567beSNathan Gauër                        {"end", 8}});
244*45b567beSNathan Gauër }
245*45b567beSNathan Gauër 
246*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, LoopNested) {
247*45b567beSNathan Gauër   StringRef Assembly = R"(
248*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
249*45b567beSNathan Gauër     entry:
250*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
251*45b567beSNathan Gauër       br label %a
252*45b567beSNathan Gauër     a:
253*45b567beSNathan Gauër       br i1 %1, label %h, label %b
254*45b567beSNathan Gauër     b:
255*45b567beSNathan Gauër       br label %c
256*45b567beSNathan Gauër     c:
257*45b567beSNathan Gauër       br i1 %1, label %d, label %e
258*45b567beSNathan Gauër     d:
259*45b567beSNathan Gauër       br label %g
260*45b567beSNathan Gauër     e:
261*45b567beSNathan Gauër       br label %f
262*45b567beSNathan Gauër     f:
263*45b567beSNathan Gauër       br label %c
264*45b567beSNathan Gauër     g:
265*45b567beSNathan Gauër       br label %a
266*45b567beSNathan Gauër     h:
267*45b567beSNathan Gauër       ret void
268*45b567beSNathan Gauër     }
269*45b567beSNathan Gauër   )";
270*45b567beSNathan Gauër 
271*45b567beSNathan Gauër   run(Assembly);
272*45b567beSNathan Gauër   checkBasicBlockRank({{"entry", 0},
273*45b567beSNathan Gauër                        {"a", 1},
274*45b567beSNathan Gauër                        {"b", 2},
275*45b567beSNathan Gauër                        {"c", 3},
276*45b567beSNathan Gauër                        {"e", 4},
277*45b567beSNathan Gauër                        {"f", 5},
278*45b567beSNathan Gauër                        {"d", 6},
279*45b567beSNathan Gauër                        {"g", 7},
280*45b567beSNathan Gauër                        {"h", 8}});
281*45b567beSNathan Gauër }
282*45b567beSNathan Gauër 
283*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, IfNested) {
284*45b567beSNathan Gauër   StringRef Assembly = R"(
285*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
286*45b567beSNathan Gauër     entry:
287*45b567beSNathan Gauër       br i1 true, label %a, label %d
288*45b567beSNathan Gauër     a:
289*45b567beSNathan Gauër       br i1 true, label %b, label %c
290*45b567beSNathan Gauër     b:
291*45b567beSNathan Gauër       br label %c
292*45b567beSNathan Gauër     c:
293*45b567beSNathan Gauër       br label %j
294*45b567beSNathan Gauër     d:
295*45b567beSNathan Gauër       br i1 true, label %e, label %f
296*45b567beSNathan Gauër     e:
297*45b567beSNathan Gauër       br label %i
298*45b567beSNathan Gauër     f:
299*45b567beSNathan Gauër       br i1 true, label %g, label %h
300*45b567beSNathan Gauër     g:
301*45b567beSNathan Gauër       br label %h
302*45b567beSNathan Gauër     h:
303*45b567beSNathan Gauër       br label %i
304*45b567beSNathan Gauër     i:
305*45b567beSNathan Gauër       br label %j
306*45b567beSNathan Gauër     j:
307*45b567beSNathan Gauër       ret void
308*45b567beSNathan Gauër     }
309*45b567beSNathan Gauër   )";
310*45b567beSNathan Gauër   run(Assembly);
311*45b567beSNathan Gauër   checkBasicBlockRank({{"entry", 0},
312*45b567beSNathan Gauër                        {"a", 1},
313*45b567beSNathan Gauër                        {"d", 1},
314*45b567beSNathan Gauër                        {"b", 2},
315*45b567beSNathan Gauër                        {"e", 2},
316*45b567beSNathan Gauër                        {"f", 2},
317*45b567beSNathan Gauër                        {"c", 3},
318*45b567beSNathan Gauër                        {"g", 3},
319*45b567beSNathan Gauër                        {"h", 4},
320*45b567beSNathan Gauër                        {"i", 5},
321*45b567beSNathan Gauër                        {"j", 6}});
322*45b567beSNathan Gauër }
323*45b567beSNathan Gauër 
324*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, CheckDeathIrreducible) {
325*45b567beSNathan Gauër   StringRef Assembly = R"(
326*45b567beSNathan Gauër     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
327*45b567beSNathan Gauër     entry:
328*45b567beSNathan Gauër       %1 = icmp ne i32 0, 0
329*45b567beSNathan Gauër       br label %a
330*45b567beSNathan Gauër     b:
331*45b567beSNathan Gauër       br i1 %1, label %a, label %c
332*45b567beSNathan Gauër     c:
333*45b567beSNathan Gauër       br label %b
334*45b567beSNathan Gauër     a:
335*45b567beSNathan Gauër       br i1 %1, label %b, label %c
336*45b567beSNathan Gauër     }
337*45b567beSNathan Gauër   )";
338*45b567beSNathan Gauër 
339*45b567beSNathan Gauër   ASSERT_DEATH(
340*45b567beSNathan Gauër       { run(Assembly); },
341*45b567beSNathan Gauër       "No valid candidate in the queue. Is the graph reducible?");
342*45b567beSNathan Gauër }
343