1 //===- SPIRVPartialOrderingVisitorTests.cpp ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "SPIRVUtils.h" 10 #include "llvm/Analysis/DominanceFrontier.h" 11 #include "llvm/Analysis/PostDominators.h" 12 #include "llvm/AsmParser/Parser.h" 13 #include "llvm/IR/Instructions.h" 14 #include "llvm/IR/LLVMContext.h" 15 #include "llvm/IR/LegacyPassManager.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/IR/PassInstrumentation.h" 18 #include "llvm/IR/Type.h" 19 #include "llvm/IR/TypedPointerType.h" 20 #include "llvm/Support/SourceMgr.h" 21 22 #include "gmock/gmock.h" 23 #include "gtest/gtest.h" 24 #include <queue> 25 26 using namespace llvm; 27 using namespace llvm::SPIRV; 28 29 class SPIRVPartialOrderingVisitorTest : public testing::Test { 30 protected: 31 void TearDown() override { M.reset(); } 32 33 void run(StringRef Assembly) { 34 assert(M == nullptr && 35 "Calling runAnalysis multiple times is unsafe. See getAnalysis()."); 36 37 SMDiagnostic Error; 38 M = parseAssemblyString(Assembly, Error, Context); 39 assert(M && "Bad assembly. Bad test?"); 40 41 llvm::Function *F = M->getFunction("main"); 42 Visitor = std::make_unique<PartialOrderingVisitor>(*F); 43 } 44 45 void 46 checkBasicBlockRank(std::vector<std::pair<const char *, size_t>> &&Expected) { 47 llvm::Function *F = M->getFunction("main"); 48 auto It = Expected.begin(); 49 Visitor->partialOrderVisit(*F->begin(), [&](BasicBlock *BB) { 50 const auto &[Name, Rank] = *It; 51 EXPECT_TRUE(It != Expected.end()) 52 << "Unexpected block \"" << BB->getName() << " visited."; 53 EXPECT_TRUE(BB->getName() == Name) 54 << "Error: expected block \"" << Name << "\" got \"" << BB->getName() 55 << "\""; 56 EXPECT_EQ(Rank, Visitor->GetNodeRank(BB)) 57 << "Bad rank for BB \"" << BB->getName() << "\""; 58 It++; 59 return true; 60 }); 61 ASSERT_TRUE(It == Expected.end()) 62 << "Expected block \"" << It->first 63 << "\" but reached the end of the function instead."; 64 } 65 66 protected: 67 LLVMContext Context; 68 std::unique_ptr<Module> M; 69 std::unique_ptr<PartialOrderingVisitor> Visitor; 70 }; 71 72 TEST_F(SPIRVPartialOrderingVisitorTest, EmptyFunction) { 73 StringRef Assembly = R"( 74 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 75 ret void 76 } 77 )"; 78 79 run(Assembly); 80 checkBasicBlockRank({{"", 0}}); 81 } 82 83 TEST_F(SPIRVPartialOrderingVisitorTest, BasicBlockSwap) { 84 StringRef Assembly = R"( 85 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 86 entry: 87 br label %middle 88 exit: 89 ret void 90 middle: 91 br label %exit 92 } 93 )"; 94 95 run(Assembly); 96 checkBasicBlockRank({{"entry", 0}, {"middle", 1}, {"exit", 2}}); 97 } 98 99 // Skip condition: 100 // +-> A -+ 101 // entry -+ +-> C 102 // +------+ 103 TEST_F(SPIRVPartialOrderingVisitorTest, SkipCondition) { 104 StringRef Assembly = R"( 105 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 106 entry: 107 %1 = icmp ne i32 0, 0 108 br i1 %1, label %c, label %a 109 c: 110 ret void 111 a: 112 br label %c 113 } 114 )"; 115 116 run(Assembly); 117 checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"c", 2}}); 118 } 119 120 // Simple loop: 121 // entry -> header <-----------------+ 122 // | `-> body -> continue -+ 123 // `-> end 124 TEST_F(SPIRVPartialOrderingVisitorTest, LoopOrdering) { 125 StringRef Assembly = R"( 126 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 127 entry: 128 %1 = icmp ne i32 0, 0 129 br label %header 130 end: 131 ret void 132 body: 133 br label %continue 134 continue: 135 br label %header 136 header: 137 br i1 %1, label %body, label %end 138 } 139 )"; 140 141 run(Assembly); 142 checkBasicBlockRank( 143 {{"entry", 0}, {"header", 1}, {"body", 2}, {"continue", 3}, {"end", 4}}); 144 } 145 146 // Diamond condition: 147 // +-> A -+ 148 // entry -+ +-> C 149 // +-> B -+ 150 // 151 // A and B order can be flipped with no effect, but it must be remain 152 // deterministic/stable. 153 TEST_F(SPIRVPartialOrderingVisitorTest, DiamondCondition) { 154 StringRef Assembly = R"( 155 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 156 entry: 157 %1 = icmp ne i32 0, 0 158 br i1 %1, label %a, label %b 159 c: 160 ret void 161 b: 162 br label %c 163 a: 164 br label %c 165 } 166 )"; 167 168 run(Assembly); 169 checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}}); 170 } 171 172 // Crossing conditions: 173 // +------+ +-> C -+ 174 // +-> A -+ | | | 175 // entry -+ +--_|_-+ +-> E 176 // +-> B -+ | | 177 // +------+----> D -+ 178 // 179 // A & B have the same rank. 180 // C & D have the same rank, but are after A & B. 181 // E if the last block. 182 TEST_F(SPIRVPartialOrderingVisitorTest, CrossingCondition) { 183 StringRef Assembly = R"( 184 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 185 entry: 186 %1 = icmp ne i32 0, 0 187 br i1 %1, label %a, label %b 188 e: 189 ret void 190 c: 191 br label %e 192 b: 193 br i1 %1, label %d, label %c 194 d: 195 br label %e 196 a: 197 br i1 %1, label %c, label %d 198 } 199 )"; 200 201 run(Assembly); 202 checkBasicBlockRank( 203 {{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}, {"d", 2}, {"e", 3}}); 204 } 205 206 TEST_F(SPIRVPartialOrderingVisitorTest, LoopDiamond) { 207 StringRef Assembly = R"( 208 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 209 entry: 210 %1 = icmp ne i32 0, 0 211 br label %header 212 header: 213 br i1 %1, label %body, label %end 214 body: 215 br i1 %1, label %inside_a, label %break 216 inside_a: 217 br label %inside_b 218 inside_b: 219 br i1 %1, label %inside_c, label %inside_d 220 inside_c: 221 br label %continue 222 inside_d: 223 br label %continue 224 break: 225 br label %end 226 continue: 227 br label %header 228 end: 229 ret void 230 } 231 )"; 232 233 run(Assembly); 234 checkBasicBlockRank({{"entry", 0}, 235 {"header", 1}, 236 {"body", 2}, 237 {"inside_a", 3}, 238 {"inside_b", 4}, 239 {"inside_c", 5}, 240 {"inside_d", 5}, 241 {"continue", 6}, 242 {"break", 7}, 243 {"end", 8}}); 244 } 245 246 TEST_F(SPIRVPartialOrderingVisitorTest, LoopNested) { 247 StringRef Assembly = R"( 248 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 249 entry: 250 %1 = icmp ne i32 0, 0 251 br label %a 252 a: 253 br i1 %1, label %h, label %b 254 b: 255 br label %c 256 c: 257 br i1 %1, label %d, label %e 258 d: 259 br label %g 260 e: 261 br label %f 262 f: 263 br label %c 264 g: 265 br label %a 266 h: 267 ret void 268 } 269 )"; 270 271 run(Assembly); 272 checkBasicBlockRank({{"entry", 0}, 273 {"a", 1}, 274 {"b", 2}, 275 {"c", 3}, 276 {"e", 4}, 277 {"f", 5}, 278 {"d", 6}, 279 {"g", 7}, 280 {"h", 8}}); 281 } 282 283 TEST_F(SPIRVPartialOrderingVisitorTest, IfNested) { 284 StringRef Assembly = R"( 285 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 286 entry: 287 br i1 true, label %a, label %d 288 a: 289 br i1 true, label %b, label %c 290 b: 291 br label %c 292 c: 293 br label %j 294 d: 295 br i1 true, label %e, label %f 296 e: 297 br label %i 298 f: 299 br i1 true, label %g, label %h 300 g: 301 br label %h 302 h: 303 br label %i 304 i: 305 br label %j 306 j: 307 ret void 308 } 309 )"; 310 run(Assembly); 311 checkBasicBlockRank({{"entry", 0}, 312 {"a", 1}, 313 {"d", 1}, 314 {"b", 2}, 315 {"e", 2}, 316 {"f", 2}, 317 {"c", 3}, 318 {"g", 3}, 319 {"h", 4}, 320 {"i", 5}, 321 {"j", 6}}); 322 } 323 324 TEST_F(SPIRVPartialOrderingVisitorTest, CheckDeathIrreducible) { 325 StringRef Assembly = R"( 326 define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 327 entry: 328 %1 = icmp ne i32 0, 0 329 br label %a 330 b: 331 br i1 %1, label %a, label %c 332 c: 333 br label %b 334 a: 335 br i1 %1, label %b, label %c 336 } 337 )"; 338 339 ASSERT_DEATH( 340 { run(Assembly); }, 341 "No valid candidate in the queue. Is the graph reducible?"); 342 } 343