1 //===- DependencyGraphTest.cpp --------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/SandboxIR/SandboxIR.h" 12 #include "llvm/Support/SourceMgr.h" 13 #include "gmock/gmock-matchers.h" 14 #include "gtest/gtest.h" 15 16 using namespace llvm; 17 18 struct DependencyGraphTest : public testing::Test { 19 LLVMContext C; 20 std::unique_ptr<Module> M; 21 22 void parseIR(LLVMContext &C, const char *IR) { 23 SMDiagnostic Err; 24 M = parseAssemblyString(IR, Err, C); 25 if (!M) 26 Err.print("DependencyGraphTest", errs()); 27 } 28 }; 29 30 TEST_F(DependencyGraphTest, DGNode_IsMem) { 31 parseIR(C, R"IR( 32 declare void @llvm.sideeffect() 33 declare void @llvm.pseudoprobe(i64, i64, i32, i64) 34 declare void @llvm.fake.use(...) 35 declare void @bar() 36 define void @foo(i8 %v1, ptr %ptr) { 37 store i8 %v1, ptr %ptr 38 %ld0 = load i8, ptr %ptr 39 %add = add i8 %v1, %v1 40 %stacksave = call ptr @llvm.stacksave() 41 call void @llvm.stackrestore(ptr %stacksave) 42 call void @llvm.sideeffect() 43 call void @llvm.pseudoprobe(i64 42, i64 1, i32 0, i64 -1) 44 call void @llvm.fake.use(ptr %ptr) 45 call void @bar() 46 ret void 47 } 48 )IR"); 49 llvm::Function *LLVMF = &*M->getFunction("foo"); 50 sandboxir::Context Ctx(C); 51 auto *F = Ctx.createFunction(LLVMF); 52 auto *BB = &*F->begin(); 53 auto It = BB->begin(); 54 auto *Store = cast<sandboxir::StoreInst>(&*It++); 55 auto *Load = cast<sandboxir::LoadInst>(&*It++); 56 auto *Add = cast<sandboxir::BinaryOperator>(&*It++); 57 auto *StackSave = cast<sandboxir::CallInst>(&*It++); 58 auto *StackRestore = cast<sandboxir::CallInst>(&*It++); 59 auto *SideEffect = cast<sandboxir::CallInst>(&*It++); 60 auto *PseudoProbe = cast<sandboxir::CallInst>(&*It++); 61 auto *FakeUse = cast<sandboxir::CallInst>(&*It++); 62 auto *Call = cast<sandboxir::CallInst>(&*It++); 63 auto *Ret = cast<sandboxir::ReturnInst>(&*It++); 64 65 sandboxir::DependencyGraph DAG; 66 DAG.extend({&*BB->begin(), BB->getTerminator()}); 67 EXPECT_TRUE(DAG.getNode(Store)->isMem()); 68 EXPECT_TRUE(DAG.getNode(Load)->isMem()); 69 EXPECT_FALSE(DAG.getNode(Add)->isMem()); 70 EXPECT_TRUE(DAG.getNode(StackSave)->isMem()); 71 EXPECT_TRUE(DAG.getNode(StackRestore)->isMem()); 72 EXPECT_FALSE(DAG.getNode(SideEffect)->isMem()); 73 EXPECT_FALSE(DAG.getNode(PseudoProbe)->isMem()); 74 EXPECT_TRUE(DAG.getNode(FakeUse)->isMem()); 75 EXPECT_TRUE(DAG.getNode(Call)->isMem()); 76 EXPECT_FALSE(DAG.getNode(Ret)->isMem()); 77 } 78 79 TEST_F(DependencyGraphTest, Basic) { 80 parseIR(C, R"IR( 81 define void @foo(ptr %ptr, i8 %v0, i8 %v1) { 82 store i8 %v0, ptr %ptr 83 store i8 %v1, ptr %ptr 84 ret void 85 } 86 )IR"); 87 llvm::Function *LLVMF = &*M->getFunction("foo"); 88 sandboxir::Context Ctx(C); 89 auto *F = Ctx.createFunction(LLVMF); 90 auto *BB = &*F->begin(); 91 auto It = BB->begin(); 92 auto *S0 = cast<sandboxir::StoreInst>(&*It++); 93 auto *S1 = cast<sandboxir::StoreInst>(&*It++); 94 auto *Ret = cast<sandboxir::ReturnInst>(&*It++); 95 sandboxir::DependencyGraph DAG; 96 auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()}); 97 // Check extend(). 98 EXPECT_EQ(Span.top(), &*BB->begin()); 99 EXPECT_EQ(Span.bottom(), BB->getTerminator()); 100 101 sandboxir::DGNode *N0 = DAG.getNode(S0); 102 sandboxir::DGNode *N1 = DAG.getNode(S1); 103 sandboxir::DGNode *N2 = DAG.getNode(Ret); 104 // Check getInstruction(). 105 EXPECT_EQ(N0->getInstruction(), S0); 106 EXPECT_EQ(N1->getInstruction(), S1); 107 // Check hasMemPred() 108 EXPECT_TRUE(N1->hasMemPred(N0)); 109 EXPECT_FALSE(N0->hasMemPred(N1)); 110 111 // Check memPreds(). 112 EXPECT_TRUE(N0->memPreds().empty()); 113 EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0)); 114 EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1)); 115 } 116