xref: /llvm-project/llvm/unittests/Target/SPIRV/SPIRVPartialOrderingVisitorTests.cpp (revision 45b567be8d0d430c786c41f826d192fadf863bb8)
1 //===- SPIRVPartialOrderingVisitorTests.cpp ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SPIRVUtils.h"
10 #include "llvm/Analysis/DominanceFrontier.h"
11 #include "llvm/Analysis/PostDominators.h"
12 #include "llvm/AsmParser/Parser.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/LLVMContext.h"
15 #include "llvm/IR/LegacyPassManager.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/PassInstrumentation.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/IR/TypedPointerType.h"
20 #include "llvm/Support/SourceMgr.h"
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include <queue>
25 
26 using namespace llvm;
27 using namespace llvm::SPIRV;
28 
29 class SPIRVPartialOrderingVisitorTest : public testing::Test {
30 protected:
31   void TearDown() override { M.reset(); }
32 
33   void run(StringRef Assembly) {
34     assert(M == nullptr &&
35            "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
36 
37     SMDiagnostic Error;
38     M = parseAssemblyString(Assembly, Error, Context);
39     assert(M && "Bad assembly. Bad test?");
40 
41     llvm::Function *F = M->getFunction("main");
42     Visitor = std::make_unique<PartialOrderingVisitor>(*F);
43   }
44 
45   void
46   checkBasicBlockRank(std::vector<std::pair<const char *, size_t>> &&Expected) {
47     llvm::Function *F = M->getFunction("main");
48     auto It = Expected.begin();
49     Visitor->partialOrderVisit(*F->begin(), [&](BasicBlock *BB) {
50       const auto &[Name, Rank] = *It;
51       EXPECT_TRUE(It != Expected.end())
52           << "Unexpected block \"" << BB->getName() << " visited.";
53       EXPECT_TRUE(BB->getName() == Name)
54           << "Error: expected block \"" << Name << "\" got \"" << BB->getName()
55           << "\"";
56       EXPECT_EQ(Rank, Visitor->GetNodeRank(BB))
57           << "Bad rank for BB \"" << BB->getName() << "\"";
58       It++;
59       return true;
60     });
61     ASSERT_TRUE(It == Expected.end())
62         << "Expected block \"" << It->first
63         << "\" but reached the end of the function instead.";
64   }
65 
66 protected:
67   LLVMContext Context;
68   std::unique_ptr<Module> M;
69   std::unique_ptr<PartialOrderingVisitor> Visitor;
70 };
71 
72 TEST_F(SPIRVPartialOrderingVisitorTest, EmptyFunction) {
73   StringRef Assembly = R"(
74     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
75       ret void
76     }
77   )";
78 
79   run(Assembly);
80   checkBasicBlockRank({{"", 0}});
81 }
82 
83 TEST_F(SPIRVPartialOrderingVisitorTest, BasicBlockSwap) {
84   StringRef Assembly = R"(
85     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
86     entry:
87       br label %middle
88     exit:
89       ret void
90     middle:
91       br label %exit
92     }
93   )";
94 
95   run(Assembly);
96   checkBasicBlockRank({{"entry", 0}, {"middle", 1}, {"exit", 2}});
97 }
98 
99 // Skip condition:
100 //         +-> A -+
101 //  entry -+      +-> C
102 //         +------+
103 TEST_F(SPIRVPartialOrderingVisitorTest, SkipCondition) {
104   StringRef Assembly = R"(
105     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
106     entry:
107       %1 = icmp ne i32 0, 0
108       br i1 %1, label %c, label %a
109     c:
110       ret void
111     a:
112       br label %c
113     }
114   )";
115 
116   run(Assembly);
117   checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"c", 2}});
118 }
119 
120 // Simple loop:
121 // entry -> header <-----------------+
122 //           | `-> body -> continue -+
123 //           `-> end
124 TEST_F(SPIRVPartialOrderingVisitorTest, LoopOrdering) {
125   StringRef Assembly = R"(
126     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
127     entry:
128       %1 = icmp ne i32 0, 0
129       br label %header
130     end:
131       ret void
132     body:
133       br label %continue
134     continue:
135       br label %header
136     header:
137       br i1 %1, label %body, label %end
138     }
139   )";
140 
141   run(Assembly);
142   checkBasicBlockRank(
143       {{"entry", 0}, {"header", 1}, {"body", 2}, {"continue", 3}, {"end", 4}});
144 }
145 
146 // Diamond condition:
147 //         +-> A -+
148 //  entry -+      +-> C
149 //         +-> B -+
150 //
// The order of A and B can be flipped with no effect, but it must remain
// deterministic/stable.
153 TEST_F(SPIRVPartialOrderingVisitorTest, DiamondCondition) {
154   StringRef Assembly = R"(
155     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
156     entry:
157       %1 = icmp ne i32 0, 0
158       br i1 %1, label %a, label %b
159     c:
160       ret void
161     b:
162       br label %c
163     a:
164       br label %c
165     }
166   )";
167 
168   run(Assembly);
169   checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}});
170 }
171 
172 // Crossing conditions:
173 //             +------+  +-> C -+
174 //         +-> A -+   |  |      |
175 //  entry -+      +--_|_-+      +-> E
176 //         +-> B -+   |         |
177 //             +------+----> D -+
178 //
179 // A & B have the same rank.
180 // C & D have the same rank, but are after A & B.
// E is the last block.
182 TEST_F(SPIRVPartialOrderingVisitorTest, CrossingCondition) {
183   StringRef Assembly = R"(
184     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
185     entry:
186       %1 = icmp ne i32 0, 0
187       br i1 %1, label %a, label %b
188     e:
189       ret void
190     c:
191       br label %e
192     b:
193       br i1 %1, label %d, label %c
194     d:
195       br label %e
196     a:
197       br i1 %1, label %c, label %d
198     }
199   )";
200 
201   run(Assembly);
202   checkBasicBlockRank(
203       {{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}, {"d", 2}, {"e", 3}});
204 }
205 
206 TEST_F(SPIRVPartialOrderingVisitorTest, LoopDiamond) {
207   StringRef Assembly = R"(
208     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
209     entry:
210       %1 = icmp ne i32 0, 0
211       br label %header
212     header:
213       br i1 %1, label %body, label %end
214     body:
215       br i1 %1, label %inside_a, label %break
216     inside_a:
217       br label %inside_b
218     inside_b:
219       br i1 %1, label %inside_c, label %inside_d
220     inside_c:
221       br label %continue
222     inside_d:
223       br label %continue
224     break:
225       br label %end
226     continue:
227       br label %header
228     end:
229       ret void
230     }
231   )";
232 
233   run(Assembly);
234   checkBasicBlockRank({{"entry", 0},
235                        {"header", 1},
236                        {"body", 2},
237                        {"inside_a", 3},
238                        {"inside_b", 4},
239                        {"inside_c", 5},
240                        {"inside_d", 5},
241                        {"continue", 6},
242                        {"break", 7},
243                        {"end", 8}});
244 }
245 
246 TEST_F(SPIRVPartialOrderingVisitorTest, LoopNested) {
247   StringRef Assembly = R"(
248     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
249     entry:
250       %1 = icmp ne i32 0, 0
251       br label %a
252     a:
253       br i1 %1, label %h, label %b
254     b:
255       br label %c
256     c:
257       br i1 %1, label %d, label %e
258     d:
259       br label %g
260     e:
261       br label %f
262     f:
263       br label %c
264     g:
265       br label %a
266     h:
267       ret void
268     }
269   )";
270 
271   run(Assembly);
272   checkBasicBlockRank({{"entry", 0},
273                        {"a", 1},
274                        {"b", 2},
275                        {"c", 3},
276                        {"e", 4},
277                        {"f", 5},
278                        {"d", 6},
279                        {"g", 7},
280                        {"h", 8}});
281 }
282 
283 TEST_F(SPIRVPartialOrderingVisitorTest, IfNested) {
284   StringRef Assembly = R"(
285     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
286     entry:
287       br i1 true, label %a, label %d
288     a:
289       br i1 true, label %b, label %c
290     b:
291       br label %c
292     c:
293       br label %j
294     d:
295       br i1 true, label %e, label %f
296     e:
297       br label %i
298     f:
299       br i1 true, label %g, label %h
300     g:
301       br label %h
302     h:
303       br label %i
304     i:
305       br label %j
306     j:
307       ret void
308     }
309   )";
310   run(Assembly);
311   checkBasicBlockRank({{"entry", 0},
312                        {"a", 1},
313                        {"d", 1},
314                        {"b", 2},
315                        {"e", 2},
316                        {"f", 2},
317                        {"c", 3},
318                        {"g", 3},
319                        {"h", 4},
320                        {"i", 5},
321                        {"j", 6}});
322 }
323 
324 TEST_F(SPIRVPartialOrderingVisitorTest, CheckDeathIrreducible) {
325   StringRef Assembly = R"(
326     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
327     entry:
328       %1 = icmp ne i32 0, 0
329       br label %a
330     b:
331       br i1 %1, label %a, label %c
332     c:
333       br label %b
334     a:
335       br i1 %1, label %b, label %c
336     }
337   )";
338 
339   ASSERT_DEATH(
340       { run(Assembly); },
341       "No valid candidate in the queue. Is the graph reducible?");
342 }
343