xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (revision ca47f48a5c9e81ef8b5c4a5b1fbc473ea5d5497c)
1 //===- DependencyGraphTest.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/SandboxIR/Constant.h"
12 #include "llvm/SandboxIR/Context.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/Support/SourceMgr.h"
15 #include "gmock/gmock-matchers.h"
16 #include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
20 struct DependencyGraphTest : public testing::Test {
21   LLVMContext C;
22   std::unique_ptr<Module> M;
23 
24   void parseIR(LLVMContext &C, const char *IR) {
25     SMDiagnostic Err;
26     M = parseAssemblyString(IR, Err, C);
27     if (!M)
28       Err.print("DependencyGraphTest", errs());
29   }
30 };
31 
32 TEST_F(DependencyGraphTest, DGNode_IsMem) {
33   parseIR(C, R"IR(
34 declare void @llvm.sideeffect()
35 declare void @llvm.pseudoprobe(i64, i64, i32, i64)
36 declare void @llvm.fake.use(...)
37 declare void @bar()
38 define void @foo(i8 %v1, ptr %ptr) {
39   store i8 %v1, ptr %ptr
40   %ld0 = load i8, ptr %ptr
41   %add = add i8 %v1, %v1
42   %stacksave = call ptr @llvm.stacksave()
43   call void @llvm.stackrestore(ptr %stacksave)
44   call void @llvm.sideeffect()
45   call void @llvm.pseudoprobe(i64 42, i64 1, i32 0, i64 -1)
46   call void @llvm.fake.use(ptr %ptr)
47   call void @bar()
48   ret void
49 }
50 )IR");
51   llvm::Function *LLVMF = &*M->getFunction("foo");
52   sandboxir::Context Ctx(C);
53   auto *F = Ctx.createFunction(LLVMF);
54   auto *BB = &*F->begin();
55   auto It = BB->begin();
56   auto *Store = cast<sandboxir::StoreInst>(&*It++);
57   auto *Load = cast<sandboxir::LoadInst>(&*It++);
58   auto *Add = cast<sandboxir::BinaryOperator>(&*It++);
59   auto *StackSave = cast<sandboxir::CallInst>(&*It++);
60   auto *StackRestore = cast<sandboxir::CallInst>(&*It++);
61   auto *SideEffect = cast<sandboxir::CallInst>(&*It++);
62   auto *PseudoProbe = cast<sandboxir::CallInst>(&*It++);
63   auto *FakeUse = cast<sandboxir::CallInst>(&*It++);
64   auto *Call = cast<sandboxir::CallInst>(&*It++);
65   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
66 
67   sandboxir::DependencyGraph DAG;
68   DAG.extend({&*BB->begin(), BB->getTerminator()});
69   EXPECT_TRUE(DAG.getNode(Store)->isMem());
70   EXPECT_TRUE(DAG.getNode(Load)->isMem());
71   EXPECT_FALSE(DAG.getNode(Add)->isMem());
72   EXPECT_TRUE(DAG.getNode(StackSave)->isMem());
73   EXPECT_TRUE(DAG.getNode(StackRestore)->isMem());
74   EXPECT_FALSE(DAG.getNode(SideEffect)->isMem());
75   EXPECT_FALSE(DAG.getNode(PseudoProbe)->isMem());
76   EXPECT_TRUE(DAG.getNode(FakeUse)->isMem());
77   EXPECT_TRUE(DAG.getNode(Call)->isMem());
78   EXPECT_FALSE(DAG.getNode(Ret)->isMem());
79 }
80 
81 TEST_F(DependencyGraphTest, Basic) {
82   parseIR(C, R"IR(
83 define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
84   store i8 %v0, ptr %ptr
85   store i8 %v1, ptr %ptr
86   ret void
87 }
88 )IR");
89   llvm::Function *LLVMF = &*M->getFunction("foo");
90   sandboxir::Context Ctx(C);
91   auto *F = Ctx.createFunction(LLVMF);
92   auto *BB = &*F->begin();
93   auto It = BB->begin();
94   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
95   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
96   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
97   sandboxir::DependencyGraph DAG;
98   auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()});
99   // Check extend().
100   EXPECT_EQ(Span.top(), &*BB->begin());
101   EXPECT_EQ(Span.bottom(), BB->getTerminator());
102 
103   sandboxir::DGNode *N0 = DAG.getNode(S0);
104   sandboxir::DGNode *N1 = DAG.getNode(S1);
105   sandboxir::DGNode *N2 = DAG.getNode(Ret);
106   // Check getInstruction().
107   EXPECT_EQ(N0->getInstruction(), S0);
108   EXPECT_EQ(N1->getInstruction(), S1);
109   // Check hasMemPred()
110   EXPECT_TRUE(N1->hasMemPred(N0));
111   EXPECT_FALSE(N0->hasMemPred(N1));
112 
113   // Check memPreds().
114   EXPECT_TRUE(N0->memPreds().empty());
115   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
116   EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
117 }
118