xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (revision 0c9f7ef52739b28f42c03c2bd1c87b744b687e6f)
1 //===- DependencyGraphTest.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/SandboxIR/SandboxIR.h"
12 #include "llvm/Support/SourceMgr.h"
13 #include "gmock/gmock-matchers.h"
14 #include "gtest/gtest.h"
15 
16 using namespace llvm;
17 
18 struct DependencyGraphTest : public testing::Test {
19   LLVMContext C;
20   std::unique_ptr<Module> M;
21 
22   void parseIR(LLVMContext &C, const char *IR) {
23     SMDiagnostic Err;
24     M = parseAssemblyString(IR, Err, C);
25     if (!M)
26       Err.print("DependencyGraphTest", errs());
27   }
28 };
29 
30 TEST_F(DependencyGraphTest, Basic) {
31   parseIR(C, R"IR(
32 define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
33   store i8 %v0, ptr %ptr
34   store i8 %v1, ptr %ptr
35   ret void
36 }
37 )IR");
38   llvm::Function *LLVMF = &*M->getFunction("foo");
39   sandboxir::Context Ctx(C);
40   auto *F = Ctx.createFunction(LLVMF);
41   auto *BB = &*F->begin();
42   auto It = BB->begin();
43   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
44   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
45   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
46   sandboxir::DependencyGraph DAG;
47   auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()});
48   // Check extend().
49   EXPECT_EQ(Span.top(), &*BB->begin());
50   EXPECT_EQ(Span.bottom(), BB->getTerminator());
51 
52   sandboxir::DGNode *N0 = DAG.getNode(S0);
53   sandboxir::DGNode *N1 = DAG.getNode(S1);
54   sandboxir::DGNode *N2 = DAG.getNode(Ret);
55   // Check getInstruction().
56   EXPECT_EQ(N0->getInstruction(), S0);
57   EXPECT_EQ(N1->getInstruction(), S1);
58   // Check hasMemPred()
59   EXPECT_TRUE(N1->hasMemPred(N0));
60   EXPECT_FALSE(N0->hasMemPred(N1));
61 
62   // Check memPreds().
63   EXPECT_TRUE(N0->memPreds().empty());
64   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
65   EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
66 }
67