xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp (revision 5cb2db3b51c2a9d516d57bd2f07d9899bd5fdae7)
1 //===- SchedulerTest.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/AliasAnalysis.h"
12 #include "llvm/Analysis/AssumptionCache.h"
13 #include "llvm/Analysis/BasicAliasAnalysis.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/SandboxIR/Context.h"
18 #include "llvm/SandboxIR/Function.h"
19 #include "llvm/SandboxIR/Instruction.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "gmock/gmock-matchers.h"
22 #include "gtest/gtest.h"
23 
24 using namespace llvm;
25 
26 struct SchedulerTest : public testing::Test {
27   LLVMContext C;
28   std::unique_ptr<Module> M;
29   std::unique_ptr<AssumptionCache> AC;
30   std::unique_ptr<DominatorTree> DT;
31   std::unique_ptr<BasicAAResult> BAA;
32   std::unique_ptr<AAResults> AA;
33 
34   void parseIR(LLVMContext &C, const char *IR) {
35     SMDiagnostic Err;
36     M = parseAssemblyString(IR, Err, C);
37     if (!M)
38       Err.print("SchedulerTest", errs());
39   }
40 
41   AAResults &getAA(llvm::Function &LLVMF) {
42     TargetLibraryInfoImpl TLII;
43     TargetLibraryInfo TLI(TLII);
44     AA = std::make_unique<AAResults>(TLI);
45     AC = std::make_unique<AssumptionCache>(LLVMF);
46     DT = std::make_unique<DominatorTree>(LLVMF);
47     BAA = std::make_unique<BasicAAResult>(M->getDataLayout(), LLVMF, TLI, *AC,
48                                           DT.get());
49     AA->addAAResult(*BAA);
50     return *AA;
51   }
52 };
53 
54 static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
55                                                   StringRef Name) {
56   for (sandboxir::BasicBlock &BB : *F)
57     if (BB.getName() == Name)
58       return &BB;
59   llvm_unreachable("Expected to find basic block!");
60 }
61 
// Exercises SchedBundle directly: getTop()/getBot(), cluster() at several
// insertion points, and iteration over the bundle's nodes.
TEST_F(SchedulerTest, SchedBundle) {
  parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
  store i8 %v0, ptr %ptr
  %other = add i8 %v0, %v1
  store i8 %v1, ptr %ptr
  ret void
}
)IR");
  llvm::Function *LLVMF = &*M->getFunction("foo");
  sandboxir::Context Ctx(C);
  auto *F = Ctx.createFunction(LLVMF);
  auto *BB = &*F->begin();
  auto It = BB->begin();
  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
  auto *Other = &*It++;
  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

  // Build the DAG over the whole block so that nodes exist for S0 and S1.
  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
  DAG.extend({&*BB->begin(), BB->getTerminator()});
  auto *SN0 = DAG.getNode(S0);
  auto *SN1 = DAG.getNode(S1);
  // The bundle holds only the two stores; Other is not a member.
  sandboxir::SchedBundle Bndl({SN0, SN1});

  // Check getTop().
  EXPECT_EQ(Bndl.getTop(), SN0);
  // Check getBot().
  EXPECT_EQ(Bndl.getBot(), SN1);
  // Check cluster().
  // Every step below starts from the order S0, Other, S1, Ret. It clusters the
  // bundle at a different position and checks the resulting order. It then
  // moves instructions so the block is back to S0, Other, S1, Ret for the
  // next step.
  //
  // Cluster at S1, which is itself a bundle member.
  Bndl.cluster(S1->getIterator());
  {
    auto It = BB->begin();
    EXPECT_EQ(&*It++, Other);
    EXPECT_EQ(&*It++, S0);
    EXPECT_EQ(&*It++, S1);
    EXPECT_EQ(&*It++, Ret);
    // Restore: Other, S0, S1, Ret -> S0, Other, S1, Ret.
    S0->moveBefore(Other);
  }

  // Cluster at S0, the top bundle member.
  Bndl.cluster(S0->getIterator());
  {
    auto It = BB->begin();
    EXPECT_EQ(&*It++, S0);
    EXPECT_EQ(&*It++, S1);
    EXPECT_EQ(&*It++, Other);
    EXPECT_EQ(&*It++, Ret);
    // Restore: S0, S1, Other, Ret -> S0, Other, S1, Ret.
    S1->moveAfter(Other);
  }

  // Cluster at Other, a non-member located between the two members.
  Bndl.cluster(Other->getIterator());
  {
    auto It = BB->begin();
    EXPECT_EQ(&*It++, S0);
    EXPECT_EQ(&*It++, S1);
    EXPECT_EQ(&*It++, Other);
    EXPECT_EQ(&*It++, Ret);
    // Restore: S0, S1, Other, Ret -> S0, Other, S1, Ret.
    S1->moveAfter(Other);
  }

  // Cluster at Ret, a non-member below the whole bundle.
  Bndl.cluster(Ret->getIterator());
  {
    auto It = BB->begin();
    EXPECT_EQ(&*It++, Other);
    EXPECT_EQ(&*It++, S0);
    EXPECT_EQ(&*It++, S1);
    EXPECT_EQ(&*It++, Ret);
    // Restore: Other, S0, S1, Ret -> S0, Other, S1, Ret.
    Other->moveBefore(S1);
  }

  // Cluster at the end of the block. This temporarily places the stores
  // after the terminator.
  Bndl.cluster(BB->end());
  {
    auto It = BB->begin();
    EXPECT_EQ(&*It++, Other);
    EXPECT_EQ(&*It++, Ret);
    EXPECT_EQ(&*It++, S0);
    EXPECT_EQ(&*It++, S1);
    // Restore: Other, Ret, S0, S1 -> Other, S0, S1, Ret -> S0, Other, S1, Ret.
    Ret->moveAfter(S1);
    Other->moveAfter(S0);
  }
  // Check iterators.
  EXPECT_THAT(Bndl, testing::ElementsAre(SN0, SN1));
  EXPECT_THAT((const sandboxir::SchedBundle &)Bndl,
              testing::ElementsAre(SN0, SN1));
}
147 
148 TEST_F(SchedulerTest, Basic) {
149   parseIR(C, R"IR(
150 define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
151   store i8 %v0, ptr %ptr
152   store i8 %v1, ptr %ptr
153   ret void
154 }
155 )IR");
156   llvm::Function *LLVMF = &*M->getFunction("foo");
157   sandboxir::Context Ctx(C);
158   auto *F = Ctx.createFunction(LLVMF);
159   auto *BB = &*F->begin();
160   auto It = BB->begin();
161   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
162   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
163   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
164 
165   {
166     // Schedule all instructions in sequence.
167     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
168     EXPECT_TRUE(Sched.trySchedule({Ret}));
169     EXPECT_TRUE(Sched.trySchedule({S1}));
170     EXPECT_TRUE(Sched.trySchedule({S0}));
171   }
172   {
173     // Skip instructions.
174     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
175     EXPECT_TRUE(Sched.trySchedule({Ret}));
176     EXPECT_TRUE(Sched.trySchedule({S0}));
177   }
178   {
179     // Try invalid scheduling. Dependency S0->S1.
180     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
181     EXPECT_TRUE(Sched.trySchedule({Ret}));
182     EXPECT_FALSE(Sched.trySchedule({S0, S1}));
183   }
184 }
185 
186 TEST_F(SchedulerTest, Bundles) {
187   parseIR(C, R"IR(
188 define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
189   %ld0 = load i8, ptr %ptr0
190   %ld1 = load i8, ptr %ptr1
191   store i8 %ld0, ptr %ptr0
192   store i8 %ld1, ptr %ptr1
193   ret void
194 }
195 )IR");
196   llvm::Function *LLVMF = &*M->getFunction("foo");
197   sandboxir::Context Ctx(C);
198   auto *F = Ctx.createFunction(LLVMF);
199   auto *BB = &*F->begin();
200   auto It = BB->begin();
201   auto *L0 = cast<sandboxir::LoadInst>(&*It++);
202   auto *L1 = cast<sandboxir::LoadInst>(&*It++);
203   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
204   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
205   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
206 
207   sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
208   EXPECT_TRUE(Sched.trySchedule({Ret}));
209   EXPECT_TRUE(Sched.trySchedule({S0, S1}));
210   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
211 }
212 
213 TEST_F(SchedulerTest, RescheduleAlreadyScheduled) {
214   parseIR(C, R"IR(
215 define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
216   %ld0 = load i8, ptr %ptr0
217   %ld1 = load i8, ptr %ptr1
218   %add0 = add i8 %ld0, %ld0
219   %add1 = add i8 %ld1, %ld1
220   store i8 %add0, ptr %ptr0
221   store i8 %add1, ptr %ptr1
222   ret void
223 }
224 )IR");
225   llvm::Function *LLVMF = &*M->getFunction("foo");
226   sandboxir::Context Ctx(C);
227   auto *F = Ctx.createFunction(LLVMF);
228   auto *BB = &*F->begin();
229   auto It = BB->begin();
230   auto *L0 = cast<sandboxir::LoadInst>(&*It++);
231   auto *L1 = cast<sandboxir::LoadInst>(&*It++);
232   auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
233   auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
234   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
235   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
236   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
237 
238   sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
239   EXPECT_TRUE(Sched.trySchedule({Ret}));
240   EXPECT_TRUE(Sched.trySchedule({S0, S1}));
241   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
242   // At this point Add0 and Add1 should have been individually scheduled
243   // as single bundles.
244   // Check if rescheduling works.
245   EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
246   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
247 }
248 
249 TEST_F(SchedulerTest, DontCrossBBs) {
250   parseIR(C, R"IR(
251 define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
252 bb0:
253   %add0 = add i8 %v0, 0
254   %add1 = add i8 %v1, 1
255   br label %bb1
256 bb1:
257   store i8 %add0, ptr %ptr0
258   store i8 %add1, ptr %ptr1
259   ret void
260 }
261 )IR");
262   llvm::Function *LLVMF = &*M->getFunction("foo");
263   sandboxir::Context Ctx(C);
264   auto *F = Ctx.createFunction(LLVMF);
265   auto *BB0 = getBasicBlockByName(F, "bb0");
266   auto *BB1 = getBasicBlockByName(F, "bb1");
267   auto It = BB0->begin();
268   auto *Add0 = &*It++;
269   auto *Add1 = &*It++;
270 
271   It = BB1->begin();
272   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
273   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
274   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
275 
276   {
277     // Schedule bottom-up
278     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
279     EXPECT_TRUE(Sched.trySchedule({Ret}));
280     EXPECT_TRUE(Sched.trySchedule({S0, S1}));
281     // Scheduling across blocks should fail.
282     EXPECT_FALSE(Sched.trySchedule({Add0, Add1}));
283   }
284   {
285     // Schedule top-down
286     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
287     EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
288     // Scheduling across blocks should fail.
289     EXPECT_FALSE(Sched.trySchedule({S0, S1}));
290   }
291 }
292