1*45b567beSNathan Gauër //===- SPIRVPartialOrderingVisitorTests.cpp ----------------------------===// 2*45b567beSNathan Gauër // 3*45b567beSNathan Gauër // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*45b567beSNathan Gauër // See https://llvm.org/LICENSE.txt for license information. 5*45b567beSNathan Gauër // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*45b567beSNathan Gauër // 7*45b567beSNathan Gauër //===----------------------------------------------------------------------===// 8*45b567beSNathan Gauër 9*45b567beSNathan Gauër #include "SPIRVUtils.h" 10*45b567beSNathan Gauër #include "llvm/Analysis/DominanceFrontier.h" 11*45b567beSNathan Gauër #include "llvm/Analysis/PostDominators.h" 12*45b567beSNathan Gauër #include "llvm/AsmParser/Parser.h" 13*45b567beSNathan Gauër #include "llvm/IR/Instructions.h" 14*45b567beSNathan Gauër #include "llvm/IR/LLVMContext.h" 15*45b567beSNathan Gauër #include "llvm/IR/LegacyPassManager.h" 16*45b567beSNathan Gauër #include "llvm/IR/Module.h" 17*45b567beSNathan Gauër #include "llvm/IR/PassInstrumentation.h" 18*45b567beSNathan Gauër #include "llvm/IR/Type.h" 19*45b567beSNathan Gauër #include "llvm/IR/TypedPointerType.h" 20*45b567beSNathan Gauër #include "llvm/Support/SourceMgr.h" 21*45b567beSNathan Gauër 22*45b567beSNathan Gauër #include "gmock/gmock.h" 23*45b567beSNathan Gauër #include "gtest/gtest.h" 24*45b567beSNathan Gauër #include <queue> 25*45b567beSNathan Gauër 26*45b567beSNathan Gauër using namespace llvm; 27*45b567beSNathan Gauër using namespace llvm::SPIRV; 28*45b567beSNathan Gauër 29*45b567beSNathan Gauër class SPIRVPartialOrderingVisitorTest : public testing::Test { 30*45b567beSNathan Gauër protected: 31*45b567beSNathan Gauër void TearDown() override { M.reset(); } 32*45b567beSNathan Gauër 33*45b567beSNathan Gauër void run(StringRef Assembly) { 34*45b567beSNathan Gauër assert(M == nullptr && 35*45b567beSNathan Gauër "Calling runAnalysis multiple times is unsafe. 
See getAnalysis()."); 36*45b567beSNathan Gauër 37*45b567beSNathan Gauër SMDiagnostic Error; 38*45b567beSNathan Gauër M = parseAssemblyString(Assembly, Error, Context); 39*45b567beSNathan Gauër assert(M && "Bad assembly. Bad test?"); 40*45b567beSNathan Gauër 41*45b567beSNathan Gauër llvm::Function *F = M->getFunction("main"); 42*45b567beSNathan Gauër Visitor = std::make_unique<PartialOrderingVisitor>(*F); 43*45b567beSNathan Gauër } 44*45b567beSNathan Gauër 45*45b567beSNathan Gauër void 46*45b567beSNathan Gauër checkBasicBlockRank(std::vector<std::pair<const char *, size_t>> &&Expected) { 47*45b567beSNathan Gauër llvm::Function *F = M->getFunction("main"); 48*45b567beSNathan Gauër auto It = Expected.begin(); 49*45b567beSNathan Gauër Visitor->partialOrderVisit(*F->begin(), [&](BasicBlock *BB) { 50*45b567beSNathan Gauër const auto &[Name, Rank] = *It; 51*45b567beSNathan Gauër EXPECT_TRUE(It != Expected.end()) 52*45b567beSNathan Gauër << "Unexpected block \"" << BB->getName() << " visited."; 53*45b567beSNathan Gauër EXPECT_TRUE(BB->getName() == Name) 54*45b567beSNathan Gauër << "Error: expected block \"" << Name << "\" got \"" << BB->getName() 55*45b567beSNathan Gauër << "\""; 56*45b567beSNathan Gauër EXPECT_EQ(Rank, Visitor->GetNodeRank(BB)) 57*45b567beSNathan Gauër << "Bad rank for BB \"" << BB->getName() << "\""; 58*45b567beSNathan Gauër It++; 59*45b567beSNathan Gauër return true; 60*45b567beSNathan Gauër }); 61*45b567beSNathan Gauër ASSERT_TRUE(It == Expected.end()) 62*45b567beSNathan Gauër << "Expected block \"" << It->first 63*45b567beSNathan Gauër << "\" but reached the end of the function instead."; 64*45b567beSNathan Gauër } 65*45b567beSNathan Gauër 66*45b567beSNathan Gauër protected: 67*45b567beSNathan Gauër LLVMContext Context; 68*45b567beSNathan Gauër std::unique_ptr<Module> M; 69*45b567beSNathan Gauër std::unique_ptr<PartialOrderingVisitor> Visitor; 70*45b567beSNathan Gauër }; 71*45b567beSNathan Gauër 72*45b567beSNathan Gauër 
// A function made of a single, unnamed entry block: that block is the only
// one visited and it gets rank 0.
TEST_F(SPIRVPartialOrderingVisitorTest, EmptyFunction) {
  run(R"(
    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
      ret void
    }
  )");
  checkBasicBlockRank({{"", 0}});
}

// Blocks are laid out as entry, exit, middle, but control flows
// entry -> middle -> exit: the visit must follow the CFG, not the layout.
TEST_F(SPIRVPartialOrderingVisitorTest, BasicBlockSwap) {
  run(R"(
    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
    entry:
      br label %middle
    exit:
      ret void
    middle:
      br label %exit
    }
  )");
  checkBasicBlockRank({{"entry", 0}, {"middle", 1}, {"exit", 2}});
}

// Skip condition:
//          +-> A --+
// entry ---+       +--> C
//          +-------+
//
// C is reachable both directly and through A, so A must come first.
TEST_F(SPIRVPartialOrderingVisitorTest, SkipCondition) {
  run(R"(
    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
    entry:
      %1 = icmp ne i32 0, 0
      br i1 %1, label %c, label %a
    c:
      ret void
    a:
      br label %c
    }
  )");
  checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"c", 2}});
}

// Simple loop:
//   entry -> header <-----------------+
//              |--> body -> continue -+
//              `--> end
//
// The loop body is visited before the loop exit.
TEST_F(SPIRVPartialOrderingVisitorTest, LoopOrdering) {
  run(R"(
    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
    entry:
      %1 = icmp ne i32 0, 0
      br label %header
    end:
      ret void
    body:
      br label %continue
    continue:
      br label %header
    header:
      br i1 %1, label %body, label %end
    }
  )");
  checkBasicBlockRank(
      {{"entry", 0}, {"header", 1}, {"body", 2}, {"continue", 3}, {"end", 4}});
}

// Diamond condition:
//          +-> A --+
// entry ---+       +--> C
//          +-> B --+
//
// The relative order of A and B could be flipped with no effect, but it must
// remain deterministic/stable.
153*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, DiamondCondition) { 154*45b567beSNathan Gauër StringRef Assembly = R"( 155*45b567beSNathan Gauër define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 156*45b567beSNathan Gauër entry: 157*45b567beSNathan Gauër %1 = icmp ne i32 0, 0 158*45b567beSNathan Gauër br i1 %1, label %a, label %b 159*45b567beSNathan Gauër c: 160*45b567beSNathan Gauër ret void 161*45b567beSNathan Gauër b: 162*45b567beSNathan Gauër br label %c 163*45b567beSNathan Gauër a: 164*45b567beSNathan Gauër br label %c 165*45b567beSNathan Gauër } 166*45b567beSNathan Gauër )"; 167*45b567beSNathan Gauër 168*45b567beSNathan Gauër run(Assembly); 169*45b567beSNathan Gauër checkBasicBlockRank({{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}}); 170*45b567beSNathan Gauër } 171*45b567beSNathan Gauër 172*45b567beSNathan Gauër // Crossing conditions: 173*45b567beSNathan Gauër // +------+ +-> C -+ 174*45b567beSNathan Gauër // +-> A -+ | | | 175*45b567beSNathan Gauër // entry -+ +--_|_-+ +-> E 176*45b567beSNathan Gauër // +-> B -+ | | 177*45b567beSNathan Gauër // +------+----> D -+ 178*45b567beSNathan Gauër // 179*45b567beSNathan Gauër // A & B have the same rank. 180*45b567beSNathan Gauër // C & D have the same rank, but are after A & B. 181*45b567beSNathan Gauër // E if the last block. 
182*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, CrossingCondition) { 183*45b567beSNathan Gauër StringRef Assembly = R"( 184*45b567beSNathan Gauër define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 185*45b567beSNathan Gauër entry: 186*45b567beSNathan Gauër %1 = icmp ne i32 0, 0 187*45b567beSNathan Gauër br i1 %1, label %a, label %b 188*45b567beSNathan Gauër e: 189*45b567beSNathan Gauër ret void 190*45b567beSNathan Gauër c: 191*45b567beSNathan Gauër br label %e 192*45b567beSNathan Gauër b: 193*45b567beSNathan Gauër br i1 %1, label %d, label %c 194*45b567beSNathan Gauër d: 195*45b567beSNathan Gauër br label %e 196*45b567beSNathan Gauër a: 197*45b567beSNathan Gauër br i1 %1, label %c, label %d 198*45b567beSNathan Gauër } 199*45b567beSNathan Gauër )"; 200*45b567beSNathan Gauër 201*45b567beSNathan Gauër run(Assembly); 202*45b567beSNathan Gauër checkBasicBlockRank( 203*45b567beSNathan Gauër {{"entry", 0}, {"a", 1}, {"b", 1}, {"c", 2}, {"d", 2}, {"e", 3}}); 204*45b567beSNathan Gauër } 205*45b567beSNathan Gauër 206*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, LoopDiamond) { 207*45b567beSNathan Gauër StringRef Assembly = R"( 208*45b567beSNathan Gauër define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 209*45b567beSNathan Gauër entry: 210*45b567beSNathan Gauër %1 = icmp ne i32 0, 0 211*45b567beSNathan Gauër br label %header 212*45b567beSNathan Gauër header: 213*45b567beSNathan Gauër br i1 %1, label %body, label %end 214*45b567beSNathan Gauër body: 215*45b567beSNathan Gauër br i1 %1, label %inside_a, label %break 216*45b567beSNathan Gauër inside_a: 217*45b567beSNathan Gauër br label %inside_b 218*45b567beSNathan Gauër inside_b: 219*45b567beSNathan Gauër br i1 %1, label %inside_c, label %inside_d 220*45b567beSNathan Gauër inside_c: 221*45b567beSNathan Gauër br label %continue 222*45b567beSNathan Gauër inside_d: 223*45b567beSNathan Gauër br label %continue 224*45b567beSNathan 
Gauër break: 225*45b567beSNathan Gauër br label %end 226*45b567beSNathan Gauër continue: 227*45b567beSNathan Gauër br label %header 228*45b567beSNathan Gauër end: 229*45b567beSNathan Gauër ret void 230*45b567beSNathan Gauër } 231*45b567beSNathan Gauër )"; 232*45b567beSNathan Gauër 233*45b567beSNathan Gauër run(Assembly); 234*45b567beSNathan Gauër checkBasicBlockRank({{"entry", 0}, 235*45b567beSNathan Gauër {"header", 1}, 236*45b567beSNathan Gauër {"body", 2}, 237*45b567beSNathan Gauër {"inside_a", 3}, 238*45b567beSNathan Gauër {"inside_b", 4}, 239*45b567beSNathan Gauër {"inside_c", 5}, 240*45b567beSNathan Gauër {"inside_d", 5}, 241*45b567beSNathan Gauër {"continue", 6}, 242*45b567beSNathan Gauër {"break", 7}, 243*45b567beSNathan Gauër {"end", 8}}); 244*45b567beSNathan Gauër } 245*45b567beSNathan Gauër 246*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, LoopNested) { 247*45b567beSNathan Gauër StringRef Assembly = R"( 248*45b567beSNathan Gauër define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 249*45b567beSNathan Gauër entry: 250*45b567beSNathan Gauër %1 = icmp ne i32 0, 0 251*45b567beSNathan Gauër br label %a 252*45b567beSNathan Gauër a: 253*45b567beSNathan Gauër br i1 %1, label %h, label %b 254*45b567beSNathan Gauër b: 255*45b567beSNathan Gauër br label %c 256*45b567beSNathan Gauër c: 257*45b567beSNathan Gauër br i1 %1, label %d, label %e 258*45b567beSNathan Gauër d: 259*45b567beSNathan Gauër br label %g 260*45b567beSNathan Gauër e: 261*45b567beSNathan Gauër br label %f 262*45b567beSNathan Gauër f: 263*45b567beSNathan Gauër br label %c 264*45b567beSNathan Gauër g: 265*45b567beSNathan Gauër br label %a 266*45b567beSNathan Gauër h: 267*45b567beSNathan Gauër ret void 268*45b567beSNathan Gauër } 269*45b567beSNathan Gauër )"; 270*45b567beSNathan Gauër 271*45b567beSNathan Gauër run(Assembly); 272*45b567beSNathan Gauër checkBasicBlockRank({{"entry", 0}, 273*45b567beSNathan Gauër {"a", 1}, 274*45b567beSNathan Gauër {"b", 2}, 
275*45b567beSNathan Gauër {"c", 3}, 276*45b567beSNathan Gauër {"e", 4}, 277*45b567beSNathan Gauër {"f", 5}, 278*45b567beSNathan Gauër {"d", 6}, 279*45b567beSNathan Gauër {"g", 7}, 280*45b567beSNathan Gauër {"h", 8}}); 281*45b567beSNathan Gauër } 282*45b567beSNathan Gauër 283*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, IfNested) { 284*45b567beSNathan Gauër StringRef Assembly = R"( 285*45b567beSNathan Gauër define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 286*45b567beSNathan Gauër entry: 287*45b567beSNathan Gauër br i1 true, label %a, label %d 288*45b567beSNathan Gauër a: 289*45b567beSNathan Gauër br i1 true, label %b, label %c 290*45b567beSNathan Gauër b: 291*45b567beSNathan Gauër br label %c 292*45b567beSNathan Gauër c: 293*45b567beSNathan Gauër br label %j 294*45b567beSNathan Gauër d: 295*45b567beSNathan Gauër br i1 true, label %e, label %f 296*45b567beSNathan Gauër e: 297*45b567beSNathan Gauër br label %i 298*45b567beSNathan Gauër f: 299*45b567beSNathan Gauër br i1 true, label %g, label %h 300*45b567beSNathan Gauër g: 301*45b567beSNathan Gauër br label %h 302*45b567beSNathan Gauër h: 303*45b567beSNathan Gauër br label %i 304*45b567beSNathan Gauër i: 305*45b567beSNathan Gauër br label %j 306*45b567beSNathan Gauër j: 307*45b567beSNathan Gauër ret void 308*45b567beSNathan Gauër } 309*45b567beSNathan Gauër )"; 310*45b567beSNathan Gauër run(Assembly); 311*45b567beSNathan Gauër checkBasicBlockRank({{"entry", 0}, 312*45b567beSNathan Gauër {"a", 1}, 313*45b567beSNathan Gauër {"d", 1}, 314*45b567beSNathan Gauër {"b", 2}, 315*45b567beSNathan Gauër {"e", 2}, 316*45b567beSNathan Gauër {"f", 2}, 317*45b567beSNathan Gauër {"c", 3}, 318*45b567beSNathan Gauër {"g", 3}, 319*45b567beSNathan Gauër {"h", 4}, 320*45b567beSNathan Gauër {"i", 5}, 321*45b567beSNathan Gauër {"j", 6}}); 322*45b567beSNathan Gauër } 323*45b567beSNathan Gauër 324*45b567beSNathan Gauër TEST_F(SPIRVPartialOrderingVisitorTest, CheckDeathIrreducible) { 
325*45b567beSNathan Gauër StringRef Assembly = R"( 326*45b567beSNathan Gauër define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" { 327*45b567beSNathan Gauër entry: 328*45b567beSNathan Gauër %1 = icmp ne i32 0, 0 329*45b567beSNathan Gauër br label %a 330*45b567beSNathan Gauër b: 331*45b567beSNathan Gauër br i1 %1, label %a, label %c 332*45b567beSNathan Gauër c: 333*45b567beSNathan Gauër br label %b 334*45b567beSNathan Gauër a: 335*45b567beSNathan Gauër br i1 %1, label %b, label %c 336*45b567beSNathan Gauër } 337*45b567beSNathan Gauër )"; 338*45b567beSNathan Gauër 339*45b567beSNathan Gauër ASSERT_DEATH( 340*45b567beSNathan Gauër { run(Assembly); }, 341*45b567beSNathan Gauër "No valid candidate in the queue. Is the graph reducible?"); 342*45b567beSNathan Gauër } 343