xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp (revision 5cb2db3b51c2a9d516d57bd2f07d9899bd5fdae7)
11d09925bSvporpo //===- SchedulerTest.cpp --------------------------------------------------===//
21d09925bSvporpo //
31d09925bSvporpo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41d09925bSvporpo // See https://llvm.org/LICENSE.txt for license information.
51d09925bSvporpo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61d09925bSvporpo //
71d09925bSvporpo //===----------------------------------------------------------------------===//
81d09925bSvporpo 
91d09925bSvporpo #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
101d09925bSvporpo #include "llvm/ADT/SmallVector.h"
111d09925bSvporpo #include "llvm/Analysis/AliasAnalysis.h"
121d09925bSvporpo #include "llvm/Analysis/AssumptionCache.h"
131d09925bSvporpo #include "llvm/Analysis/BasicAliasAnalysis.h"
141d09925bSvporpo #include "llvm/Analysis/TargetLibraryInfo.h"
151d09925bSvporpo #include "llvm/AsmParser/Parser.h"
161d09925bSvporpo #include "llvm/IR/Dominators.h"
171d09925bSvporpo #include "llvm/SandboxIR/Context.h"
181d09925bSvporpo #include "llvm/SandboxIR/Function.h"
191d09925bSvporpo #include "llvm/SandboxIR/Instruction.h"
201d09925bSvporpo #include "llvm/Support/SourceMgr.h"
211d09925bSvporpo #include "gmock/gmock-matchers.h"
221d09925bSvporpo #include "gtest/gtest.h"
231d09925bSvporpo 
241d09925bSvporpo using namespace llvm;
251d09925bSvporpo 
261d09925bSvporpo struct SchedulerTest : public testing::Test {
271d09925bSvporpo   LLVMContext C;
281d09925bSvporpo   std::unique_ptr<Module> M;
291d09925bSvporpo   std::unique_ptr<AssumptionCache> AC;
301d09925bSvporpo   std::unique_ptr<DominatorTree> DT;
311d09925bSvporpo   std::unique_ptr<BasicAAResult> BAA;
321d09925bSvporpo   std::unique_ptr<AAResults> AA;
331d09925bSvporpo 
341d09925bSvporpo   void parseIR(LLVMContext &C, const char *IR) {
351d09925bSvporpo     SMDiagnostic Err;
361d09925bSvporpo     M = parseAssemblyString(IR, Err, C);
371d09925bSvporpo     if (!M)
381d09925bSvporpo       Err.print("SchedulerTest", errs());
391d09925bSvporpo   }
401d09925bSvporpo 
411d09925bSvporpo   AAResults &getAA(llvm::Function &LLVMF) {
421d09925bSvporpo     TargetLibraryInfoImpl TLII;
431d09925bSvporpo     TargetLibraryInfo TLI(TLII);
441d09925bSvporpo     AA = std::make_unique<AAResults>(TLI);
451d09925bSvporpo     AC = std::make_unique<AssumptionCache>(LLVMF);
461d09925bSvporpo     DT = std::make_unique<DominatorTree>(LLVMF);
471d09925bSvporpo     BAA = std::make_unique<BasicAAResult>(M->getDataLayout(), LLVMF, TLI, *AC,
481d09925bSvporpo                                           DT.get());
491d09925bSvporpo     AA->addAAResult(*BAA);
501d09925bSvporpo     return *AA;
511d09925bSvporpo   }
521d09925bSvporpo };
531d09925bSvporpo 
54*5cb2db3bSvporpo static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
55*5cb2db3bSvporpo                                                   StringRef Name) {
56*5cb2db3bSvporpo   for (sandboxir::BasicBlock &BB : *F)
57*5cb2db3bSvporpo     if (BB.getName() == Name)
58*5cb2db3bSvporpo       return &BB;
59*5cb2db3bSvporpo   llvm_unreachable("Expected to find basic block!");
60*5cb2db3bSvporpo }
61*5cb2db3bSvporpo 
621d09925bSvporpo TEST_F(SchedulerTest, SchedBundle) {
631d09925bSvporpo   parseIR(C, R"IR(
641d09925bSvporpo define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
651d09925bSvporpo   store i8 %v0, ptr %ptr
661d09925bSvporpo   %other = add i8 %v0, %v1
671d09925bSvporpo   store i8 %v1, ptr %ptr
681d09925bSvporpo   ret void
691d09925bSvporpo }
701d09925bSvporpo )IR");
711d09925bSvporpo   llvm::Function *LLVMF = &*M->getFunction("foo");
721d09925bSvporpo   sandboxir::Context Ctx(C);
731d09925bSvporpo   auto *F = Ctx.createFunction(LLVMF);
741d09925bSvporpo   auto *BB = &*F->begin();
751d09925bSvporpo   auto It = BB->begin();
761d09925bSvporpo   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
771d09925bSvporpo   auto *Other = &*It++;
781d09925bSvporpo   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
791d09925bSvporpo   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
801d09925bSvporpo 
8131a4d2c2Svporpo   sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
821d09925bSvporpo   DAG.extend({&*BB->begin(), BB->getTerminator()});
831d09925bSvporpo   auto *SN0 = DAG.getNode(S0);
841d09925bSvporpo   auto *SN1 = DAG.getNode(S1);
851d09925bSvporpo   sandboxir::SchedBundle Bndl({SN0, SN1});
861d09925bSvporpo 
871d09925bSvporpo   // Check getTop().
881d09925bSvporpo   EXPECT_EQ(Bndl.getTop(), SN0);
891d09925bSvporpo   // Check getBot().
901d09925bSvporpo   EXPECT_EQ(Bndl.getBot(), SN1);
911d09925bSvporpo   // Check cluster().
921d09925bSvporpo   Bndl.cluster(S1->getIterator());
931d09925bSvporpo   {
941d09925bSvporpo     auto It = BB->begin();
951d09925bSvporpo     EXPECT_EQ(&*It++, Other);
961d09925bSvporpo     EXPECT_EQ(&*It++, S0);
971d09925bSvporpo     EXPECT_EQ(&*It++, S1);
981d09925bSvporpo     EXPECT_EQ(&*It++, Ret);
991d09925bSvporpo     S0->moveBefore(Other);
1001d09925bSvporpo   }
1011d09925bSvporpo 
1021d09925bSvporpo   Bndl.cluster(S0->getIterator());
1031d09925bSvporpo   {
1041d09925bSvporpo     auto It = BB->begin();
1051d09925bSvporpo     EXPECT_EQ(&*It++, S0);
1061d09925bSvporpo     EXPECT_EQ(&*It++, S1);
1071d09925bSvporpo     EXPECT_EQ(&*It++, Other);
1081d09925bSvporpo     EXPECT_EQ(&*It++, Ret);
1091d09925bSvporpo     S1->moveAfter(Other);
1101d09925bSvporpo   }
1111d09925bSvporpo 
1121d09925bSvporpo   Bndl.cluster(Other->getIterator());
1131d09925bSvporpo   {
1141d09925bSvporpo     auto It = BB->begin();
1151d09925bSvporpo     EXPECT_EQ(&*It++, S0);
1161d09925bSvporpo     EXPECT_EQ(&*It++, S1);
1171d09925bSvporpo     EXPECT_EQ(&*It++, Other);
1181d09925bSvporpo     EXPECT_EQ(&*It++, Ret);
1191d09925bSvporpo     S1->moveAfter(Other);
1201d09925bSvporpo   }
1211d09925bSvporpo 
1221d09925bSvporpo   Bndl.cluster(Ret->getIterator());
1231d09925bSvporpo   {
1241d09925bSvporpo     auto It = BB->begin();
1251d09925bSvporpo     EXPECT_EQ(&*It++, Other);
1261d09925bSvporpo     EXPECT_EQ(&*It++, S0);
1271d09925bSvporpo     EXPECT_EQ(&*It++, S1);
1281d09925bSvporpo     EXPECT_EQ(&*It++, Ret);
1291d09925bSvporpo     Other->moveBefore(S1);
1301d09925bSvporpo   }
1311d09925bSvporpo 
1321d09925bSvporpo   Bndl.cluster(BB->end());
1331d09925bSvporpo   {
1341d09925bSvporpo     auto It = BB->begin();
1351d09925bSvporpo     EXPECT_EQ(&*It++, Other);
1361d09925bSvporpo     EXPECT_EQ(&*It++, Ret);
1371d09925bSvporpo     EXPECT_EQ(&*It++, S0);
1381d09925bSvporpo     EXPECT_EQ(&*It++, S1);
1391d09925bSvporpo     Ret->moveAfter(S1);
1401d09925bSvporpo     Other->moveAfter(S0);
1411d09925bSvporpo   }
1421d09925bSvporpo   // Check iterators.
1431d09925bSvporpo   EXPECT_THAT(Bndl, testing::ElementsAre(SN0, SN1));
1441d09925bSvporpo   EXPECT_THAT((const sandboxir::SchedBundle &)Bndl,
1451d09925bSvporpo               testing::ElementsAre(SN0, SN1));
1461d09925bSvporpo }
1471d09925bSvporpo 
1481d09925bSvporpo TEST_F(SchedulerTest, Basic) {
1491d09925bSvporpo   parseIR(C, R"IR(
1501d09925bSvporpo define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
1511d09925bSvporpo   store i8 %v0, ptr %ptr
1521d09925bSvporpo   store i8 %v1, ptr %ptr
1531d09925bSvporpo   ret void
1541d09925bSvporpo }
1551d09925bSvporpo )IR");
1561d09925bSvporpo   llvm::Function *LLVMF = &*M->getFunction("foo");
1571d09925bSvporpo   sandboxir::Context Ctx(C);
1581d09925bSvporpo   auto *F = Ctx.createFunction(LLVMF);
1591d09925bSvporpo   auto *BB = &*F->begin();
1601d09925bSvporpo   auto It = BB->begin();
1611d09925bSvporpo   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
1621d09925bSvporpo   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
1631d09925bSvporpo   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
1641d09925bSvporpo 
1651d09925bSvporpo   {
1661d09925bSvporpo     // Schedule all instructions in sequence.
1675942a99fSvporpo     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
1681d09925bSvporpo     EXPECT_TRUE(Sched.trySchedule({Ret}));
1691d09925bSvporpo     EXPECT_TRUE(Sched.trySchedule({S1}));
1701d09925bSvporpo     EXPECT_TRUE(Sched.trySchedule({S0}));
1711d09925bSvporpo   }
1721d09925bSvporpo   {
1731d09925bSvporpo     // Skip instructions.
1745942a99fSvporpo     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
1751d09925bSvporpo     EXPECT_TRUE(Sched.trySchedule({Ret}));
1761d09925bSvporpo     EXPECT_TRUE(Sched.trySchedule({S0}));
1771d09925bSvporpo   }
1781d09925bSvporpo   {
179f7ef7b2fSvporpo     // Try invalid scheduling. Dependency S0->S1.
1805942a99fSvporpo     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
1811d09925bSvporpo     EXPECT_TRUE(Sched.trySchedule({Ret}));
182f7ef7b2fSvporpo     EXPECT_FALSE(Sched.trySchedule({S0, S1}));
1831d09925bSvporpo   }
1841d09925bSvporpo }
1851d09925bSvporpo 
1861d09925bSvporpo TEST_F(SchedulerTest, Bundles) {
1871d09925bSvporpo   parseIR(C, R"IR(
1881d09925bSvporpo define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
1891d09925bSvporpo   %ld0 = load i8, ptr %ptr0
1901d09925bSvporpo   %ld1 = load i8, ptr %ptr1
1911d09925bSvporpo   store i8 %ld0, ptr %ptr0
1921d09925bSvporpo   store i8 %ld1, ptr %ptr1
1931d09925bSvporpo   ret void
1941d09925bSvporpo }
1951d09925bSvporpo )IR");
1961d09925bSvporpo   llvm::Function *LLVMF = &*M->getFunction("foo");
1971d09925bSvporpo   sandboxir::Context Ctx(C);
1981d09925bSvporpo   auto *F = Ctx.createFunction(LLVMF);
1991d09925bSvporpo   auto *BB = &*F->begin();
2001d09925bSvporpo   auto It = BB->begin();
2011d09925bSvporpo   auto *L0 = cast<sandboxir::LoadInst>(&*It++);
2021d09925bSvporpo   auto *L1 = cast<sandboxir::LoadInst>(&*It++);
2031d09925bSvporpo   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
2041d09925bSvporpo   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
2051d09925bSvporpo   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
2061d09925bSvporpo 
2075942a99fSvporpo   sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
2081d09925bSvporpo   EXPECT_TRUE(Sched.trySchedule({Ret}));
2091d09925bSvporpo   EXPECT_TRUE(Sched.trySchedule({S0, S1}));
2101d09925bSvporpo   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
2111d09925bSvporpo }
212f7ef7b2fSvporpo 
213f7ef7b2fSvporpo TEST_F(SchedulerTest, RescheduleAlreadyScheduled) {
214f7ef7b2fSvporpo   parseIR(C, R"IR(
215f7ef7b2fSvporpo define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
216f7ef7b2fSvporpo   %ld0 = load i8, ptr %ptr0
217f7ef7b2fSvporpo   %ld1 = load i8, ptr %ptr1
218f7ef7b2fSvporpo   %add0 = add i8 %ld0, %ld0
219f7ef7b2fSvporpo   %add1 = add i8 %ld1, %ld1
220f7ef7b2fSvporpo   store i8 %add0, ptr %ptr0
221f7ef7b2fSvporpo   store i8 %add1, ptr %ptr1
222f7ef7b2fSvporpo   ret void
223f7ef7b2fSvporpo }
224f7ef7b2fSvporpo )IR");
225f7ef7b2fSvporpo   llvm::Function *LLVMF = &*M->getFunction("foo");
226f7ef7b2fSvporpo   sandboxir::Context Ctx(C);
227f7ef7b2fSvporpo   auto *F = Ctx.createFunction(LLVMF);
228f7ef7b2fSvporpo   auto *BB = &*F->begin();
229f7ef7b2fSvporpo   auto It = BB->begin();
230f7ef7b2fSvporpo   auto *L0 = cast<sandboxir::LoadInst>(&*It++);
231f7ef7b2fSvporpo   auto *L1 = cast<sandboxir::LoadInst>(&*It++);
232f7ef7b2fSvporpo   auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
233f7ef7b2fSvporpo   auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
234f7ef7b2fSvporpo   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
235f7ef7b2fSvporpo   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
236f7ef7b2fSvporpo   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
237f7ef7b2fSvporpo 
238f7ef7b2fSvporpo   sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
239f7ef7b2fSvporpo   EXPECT_TRUE(Sched.trySchedule({Ret}));
240f7ef7b2fSvporpo   EXPECT_TRUE(Sched.trySchedule({S0, S1}));
241f7ef7b2fSvporpo   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
242f7ef7b2fSvporpo   // At this point Add0 and Add1 should have been individually scheduled
243f7ef7b2fSvporpo   // as single bundles.
244f7ef7b2fSvporpo   // Check if rescheduling works.
245f7ef7b2fSvporpo   EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
246f7ef7b2fSvporpo   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
247f7ef7b2fSvporpo }
248*5cb2db3bSvporpo 
249*5cb2db3bSvporpo TEST_F(SchedulerTest, DontCrossBBs) {
250*5cb2db3bSvporpo   parseIR(C, R"IR(
251*5cb2db3bSvporpo define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
252*5cb2db3bSvporpo bb0:
253*5cb2db3bSvporpo   %add0 = add i8 %v0, 0
254*5cb2db3bSvporpo   %add1 = add i8 %v1, 1
255*5cb2db3bSvporpo   br label %bb1
256*5cb2db3bSvporpo bb1:
257*5cb2db3bSvporpo   store i8 %add0, ptr %ptr0
258*5cb2db3bSvporpo   store i8 %add1, ptr %ptr1
259*5cb2db3bSvporpo   ret void
260*5cb2db3bSvporpo }
261*5cb2db3bSvporpo )IR");
262*5cb2db3bSvporpo   llvm::Function *LLVMF = &*M->getFunction("foo");
263*5cb2db3bSvporpo   sandboxir::Context Ctx(C);
264*5cb2db3bSvporpo   auto *F = Ctx.createFunction(LLVMF);
265*5cb2db3bSvporpo   auto *BB0 = getBasicBlockByName(F, "bb0");
266*5cb2db3bSvporpo   auto *BB1 = getBasicBlockByName(F, "bb1");
267*5cb2db3bSvporpo   auto It = BB0->begin();
268*5cb2db3bSvporpo   auto *Add0 = &*It++;
269*5cb2db3bSvporpo   auto *Add1 = &*It++;
270*5cb2db3bSvporpo 
271*5cb2db3bSvporpo   It = BB1->begin();
272*5cb2db3bSvporpo   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
273*5cb2db3bSvporpo   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
274*5cb2db3bSvporpo   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
275*5cb2db3bSvporpo 
276*5cb2db3bSvporpo   {
277*5cb2db3bSvporpo     // Schedule bottom-up
278*5cb2db3bSvporpo     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
279*5cb2db3bSvporpo     EXPECT_TRUE(Sched.trySchedule({Ret}));
280*5cb2db3bSvporpo     EXPECT_TRUE(Sched.trySchedule({S0, S1}));
281*5cb2db3bSvporpo     // Scheduling across blocks should fail.
282*5cb2db3bSvporpo     EXPECT_FALSE(Sched.trySchedule({Add0, Add1}));
283*5cb2db3bSvporpo   }
284*5cb2db3bSvporpo   {
285*5cb2db3bSvporpo     // Schedule top-down
286*5cb2db3bSvporpo     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
287*5cb2db3bSvporpo     EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
288*5cb2db3bSvporpo     // Scheduling across blocks should fail.
289*5cb2db3bSvporpo     EXPECT_FALSE(Sched.trySchedule({S0, S1}));
290*5cb2db3bSvporpo   }
291*5cb2db3bSvporpo }
292