xref: /llvm-project/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp (revision 358d65463b215a18e731b3a5494d51e1bcbd1356)
1 //===- InstrMapsTest.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
10 #include "llvm/ADT/SmallSet.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/SandboxIR/Function.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/Support/SourceMgr.h"
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
20 struct InstrMapsTest : public testing::Test {
21   LLVMContext C;
22   std::unique_ptr<Module> M;
23 
24   void parseIR(LLVMContext &C, const char *IR) {
25     SMDiagnostic Err;
26     M = parseAssemblyString(IR, Err, C);
27     if (!M)
28       Err.print("InstrMapsTest", errs());
29   }
30 };
31 
32 TEST_F(InstrMapsTest, Basic) {
33   parseIR(C, R"IR(
34 define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
35   %add0 = add i8 %v0, %v0
36   %add1 = add i8 %v1, %v1
37   %add2 = add i8 %v2, %v2
38   %add3 = add i8 %v3, %v3
39   %vadd0 = add <2 x i8> %vec, %vec
40   ret void
41 }
42 )IR");
43   llvm::Function *LLVMF = &*M->getFunction("foo");
44   sandboxir::Context Ctx(C);
45   auto *F = Ctx.createFunction(LLVMF);
46   auto *BB = &*F->begin();
47   auto It = BB->begin();
48 
49   auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
50   auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
51   auto *Add2 = cast<sandboxir::BinaryOperator>(&*It++);
52   auto *Add3 = cast<sandboxir::BinaryOperator>(&*It++);
53   auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
54   [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
55 
56   sandboxir::InstrMaps IMaps(Ctx);
57   // Check with empty IMaps.
58   EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
59   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
60   EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0));
61   // Check with 1 match.
62   IMaps.registerVector({Add0, Add1}, VAdd0);
63   EXPECT_EQ(IMaps.getVectorForOrig(Add0), VAdd0);
64   EXPECT_EQ(IMaps.getVectorForOrig(Add1), VAdd0);
65   EXPECT_FALSE(IMaps.getOrigLane(VAdd0, VAdd0)); // Bad Orig value
66   EXPECT_FALSE(IMaps.getOrigLane(Add0, Add0));   // Bad Vector value
67   EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add0), 0U);
68   EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1U);
69   // Check when the same vector maps to different original values (which is
70   // common for vector constants).
71   IMaps.registerVector({Add2, Add3}, VAdd0);
72   EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add2), 0U);
73   EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add3), 1U);
74   // Check when we register for a second time.
75 #ifndef NDEBUG
76   EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
77 #endif // NDEBUG
78   // Check callbacks: erase original instr.
79   Add0->eraseFromParent();
80   EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
81   EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1U);
82   EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
83   // Check callbacks: erase vector instr.
84   VAdd0->eraseFromParent();
85   EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
86   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
87 }
88