1 //===- SchedulerTest.cpp --------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" 10 #include "llvm/ADT/SmallVector.h" 11 #include "llvm/Analysis/AliasAnalysis.h" 12 #include "llvm/Analysis/AssumptionCache.h" 13 #include "llvm/Analysis/BasicAliasAnalysis.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/AsmParser/Parser.h" 16 #include "llvm/IR/Dominators.h" 17 #include "llvm/SandboxIR/Context.h" 18 #include "llvm/SandboxIR/Function.h" 19 #include "llvm/SandboxIR/Instruction.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "gmock/gmock-matchers.h" 22 #include "gtest/gtest.h" 23 24 using namespace llvm; 25 26 struct SchedulerTest : public testing::Test { 27 LLVMContext C; 28 std::unique_ptr<Module> M; 29 std::unique_ptr<AssumptionCache> AC; 30 std::unique_ptr<DominatorTree> DT; 31 std::unique_ptr<BasicAAResult> BAA; 32 std::unique_ptr<AAResults> AA; 33 34 void parseIR(LLVMContext &C, const char *IR) { 35 SMDiagnostic Err; 36 M = parseAssemblyString(IR, Err, C); 37 if (!M) 38 Err.print("SchedulerTest", errs()); 39 } 40 41 AAResults &getAA(llvm::Function &LLVMF) { 42 TargetLibraryInfoImpl TLII; 43 TargetLibraryInfo TLI(TLII); 44 AA = std::make_unique<AAResults>(TLI); 45 AC = std::make_unique<AssumptionCache>(LLVMF); 46 DT = std::make_unique<DominatorTree>(LLVMF); 47 BAA = std::make_unique<BasicAAResult>(M->getDataLayout(), LLVMF, TLI, *AC, 48 DT.get()); 49 AA->addAAResult(*BAA); 50 return *AA; 51 } 52 }; 53 54 TEST_F(SchedulerTest, SchedBundle) { 55 parseIR(C, R"IR( 56 define void @foo(ptr %ptr, i8 %v0, i8 %v1) { 57 store i8 %v0, ptr %ptr 58 %other = add i8 %v0, %v1 59 store i8 %v1, ptr %ptr 60 ret void 61 } 62 )IR"); 63 llvm::Function *LLVMF = &*M->getFunction("foo"); 64 sandboxir::Context Ctx(C); 65 auto *F = Ctx.createFunction(LLVMF); 66 auto *BB = &*F->begin(); 67 auto It = BB->begin(); 68 auto *S0 = cast<sandboxir::StoreInst>(&*It++); 69 auto *Other = &*It++; 70 auto *S1 = cast<sandboxir::StoreInst>(&*It++); 71 auto *Ret = cast<sandboxir::ReturnInst>(&*It++); 72 73 sandboxir::DependencyGraph DAG(getAA(*LLVMF)); 74 DAG.extend({&*BB->begin(), BB->getTerminator()}); 75 auto *SN0 = DAG.getNode(S0); 76 auto *SN1 = DAG.getNode(S1); 77 sandboxir::SchedBundle Bndl({SN0, SN1}); 78 79 // Check getTop(). 80 EXPECT_EQ(Bndl.getTop(), SN0); 81 // Check getBot(). 82 EXPECT_EQ(Bndl.getBot(), SN1); 83 // Check cluster(). 84 Bndl.cluster(S1->getIterator()); 85 { 86 auto It = BB->begin(); 87 EXPECT_EQ(&*It++, Other); 88 EXPECT_EQ(&*It++, S0); 89 EXPECT_EQ(&*It++, S1); 90 EXPECT_EQ(&*It++, Ret); 91 S0->moveBefore(Other); 92 } 93 94 Bndl.cluster(S0->getIterator()); 95 { 96 auto It = BB->begin(); 97 EXPECT_EQ(&*It++, S0); 98 EXPECT_EQ(&*It++, S1); 99 EXPECT_EQ(&*It++, Other); 100 EXPECT_EQ(&*It++, Ret); 101 S1->moveAfter(Other); 102 } 103 104 Bndl.cluster(Other->getIterator()); 105 { 106 auto It = BB->begin(); 107 EXPECT_EQ(&*It++, S0); 108 EXPECT_EQ(&*It++, S1); 109 EXPECT_EQ(&*It++, Other); 110 EXPECT_EQ(&*It++, Ret); 111 S1->moveAfter(Other); 112 } 113 114 Bndl.cluster(Ret->getIterator()); 115 { 116 auto It = BB->begin(); 117 EXPECT_EQ(&*It++, Other); 118 EXPECT_EQ(&*It++, S0); 119 EXPECT_EQ(&*It++, S1); 120 EXPECT_EQ(&*It++, Ret); 121 Other->moveBefore(S1); 122 } 123 124 Bndl.cluster(BB->end()); 125 { 126 auto It = BB->begin(); 127 EXPECT_EQ(&*It++, Other); 128 EXPECT_EQ(&*It++, Ret); 129 EXPECT_EQ(&*It++, S0); 130 EXPECT_EQ(&*It++, S1); 131 Ret->moveAfter(S1); 132 Other->moveAfter(S0); 133 } 134 // Check iterators. 135 EXPECT_THAT(Bndl, testing::ElementsAre(SN0, SN1)); 136 EXPECT_THAT((const sandboxir::SchedBundle &)Bndl, 137 testing::ElementsAre(SN0, SN1)); 138 } 139 140 TEST_F(SchedulerTest, Basic) { 141 parseIR(C, R"IR( 142 define void @foo(ptr %ptr, i8 %v0, i8 %v1) { 143 store i8 %v0, ptr %ptr 144 store i8 %v1, ptr %ptr 145 ret void 146 } 147 )IR"); 148 llvm::Function *LLVMF = &*M->getFunction("foo"); 149 sandboxir::Context Ctx(C); 150 auto *F = Ctx.createFunction(LLVMF); 151 auto *BB = &*F->begin(); 152 auto It = BB->begin(); 153 auto *S0 = cast<sandboxir::StoreInst>(&*It++); 154 auto *S1 = cast<sandboxir::StoreInst>(&*It++); 155 auto *Ret = cast<sandboxir::ReturnInst>(&*It++); 156 157 { 158 // Schedule all instructions in sequence. 159 sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); 160 EXPECT_TRUE(Sched.trySchedule({Ret})); 161 EXPECT_TRUE(Sched.trySchedule({S1})); 162 EXPECT_TRUE(Sched.trySchedule({S0})); 163 } 164 { 165 // Skip instructions. 166 sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); 167 EXPECT_TRUE(Sched.trySchedule({Ret})); 168 EXPECT_TRUE(Sched.trySchedule({S0})); 169 } 170 { 171 // Try invalid scheduling 172 sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); 173 EXPECT_TRUE(Sched.trySchedule({Ret})); 174 EXPECT_TRUE(Sched.trySchedule({S0})); 175 EXPECT_FALSE(Sched.trySchedule({S1})); 176 } 177 } 178 179 TEST_F(SchedulerTest, Bundles) { 180 parseIR(C, R"IR( 181 define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { 182 %ld0 = load i8, ptr %ptr0 183 %ld1 = load i8, ptr %ptr1 184 store i8 %ld0, ptr %ptr0 185 store i8 %ld1, ptr %ptr1 186 ret void 187 } 188 )IR"); 189 llvm::Function *LLVMF = &*M->getFunction("foo"); 190 sandboxir::Context Ctx(C); 191 auto *F = Ctx.createFunction(LLVMF); 192 auto *BB = &*F->begin(); 193 auto It = BB->begin(); 194 auto *L0 = cast<sandboxir::LoadInst>(&*It++); 195 auto *L1 = cast<sandboxir::LoadInst>(&*It++); 196 auto *S0 = cast<sandboxir::StoreInst>(&*It++); 197 auto *S1 = cast<sandboxir::StoreInst>(&*It++); 198 auto *Ret = cast<sandboxir::ReturnInst>(&*It++); 199 200 sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); 201 EXPECT_TRUE(Sched.trySchedule({Ret})); 202 EXPECT_TRUE(Sched.trySchedule({S0, S1})); 203 EXPECT_TRUE(Sched.trySchedule({L0, L1})); 204 } 205