xref: /llvm-project/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp (revision 048cf8857e081fb80d5ac8b24a79f999d632141b)
1 //===- ModuleUtilsTest.cpp - Unit tests for Module utility ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/ModuleUtils.h"
10 #include "llvm/ADT/StringRef.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/Constants.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/Support/SourceMgr.h"
16 #include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
20 static std::unique_ptr<Module> parseIR(LLVMContext &C, StringRef IR) {
21   SMDiagnostic Err;
22   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
23   if (!Mod)
24     Err.print("ModuleUtilsTest", errs());
25   return Mod;
26 }
27 
28 static int getListSize(Module &M, StringRef Name) {
29   auto *List = M.getGlobalVariable(Name);
30   if (!List)
31     return 0;
32   auto *T = cast<ArrayType>(List->getValueType());
33   return T->getNumElements();
34 }
35 
36 TEST(ModuleUtils, AppendToUsedList1) {
37   LLVMContext C;
38 
39   std::unique_ptr<Module> M = parseIR(
40       C, R"(@x = addrspace(4) global [2 x i32] zeroinitializer, align 4)");
41   SmallVector<GlobalValue *, 2> Globals;
42   for (auto &G : M->globals()) {
43     Globals.push_back(&G);
44   }
45   EXPECT_EQ(0, getListSize(*M, "llvm.compiler.used"));
46   appendToCompilerUsed(*M, Globals);
47   EXPECT_EQ(1, getListSize(*M, "llvm.compiler.used"));
48 
49   EXPECT_EQ(0, getListSize(*M, "llvm.used"));
50   appendToUsed(*M, Globals);
51   EXPECT_EQ(1, getListSize(*M, "llvm.used"));
52 }
53 
54 TEST(ModuleUtils, AppendToUsedList2) {
55   LLVMContext C;
56 
57   std::unique_ptr<Module> M =
58       parseIR(C, R"(@x = global [2 x i32] zeroinitializer, align 4)");
59   SmallVector<GlobalValue *, 2> Globals;
60   for (auto &G : M->globals()) {
61     Globals.push_back(&G);
62   }
63   EXPECT_EQ(0, getListSize(*M, "llvm.compiler.used"));
64   appendToCompilerUsed(*M, Globals);
65   EXPECT_EQ(1, getListSize(*M, "llvm.compiler.used"));
66 
67   EXPECT_EQ(0, getListSize(*M, "llvm.used"));
68   appendToUsed(*M, Globals);
69   EXPECT_EQ(1, getListSize(*M, "llvm.used"));
70 }
71 
72 using AppendFnType = decltype(&appendToGlobalCtors);
73 using TransformFnType = decltype(&transformGlobalCtors);
74 using ParamType = std::tuple<StringRef, AppendFnType, TransformFnType>;
75 class ModuleUtilsTest : public testing::TestWithParam<ParamType> {
76 public:
77   StringRef arrayName() const { return std::get<0>(GetParam()); }
78   AppendFnType appendFn() const { return std::get<AppendFnType>(GetParam()); }
79   TransformFnType transformFn() const {
80     return std::get<TransformFnType>(GetParam());
81   }
82 };
83 
84 INSTANTIATE_TEST_SUITE_P(
85     ModuleUtilsTestCtors, ModuleUtilsTest,
86     ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors,
87                                 &transformGlobalCtors},
88                       ParamType{"llvm.global_dtors", &appendToGlobalDtors,
89                                 &transformGlobalDtors}));
90 
91 TEST_P(ModuleUtilsTest, AppendToMissingArray) {
92   LLVMContext C;
93 
94   std::unique_ptr<Module> M = parseIR(C, "");
95 
96   EXPECT_EQ(0, getListSize(*M, arrayName()));
97   Function *F = cast<Function>(
98       M->getOrInsertFunction("ctor", Type::getVoidTy(C)).getCallee());
99   appendFn()(*M, F, 11, F);
100   ASSERT_EQ(1, getListSize(*M, arrayName()));
101 
102   ConstantArray *CA = dyn_cast<ConstantArray>(
103       M->getGlobalVariable(arrayName())->getInitializer());
104   ASSERT_NE(nullptr, CA);
105   ConstantStruct *CS = dyn_cast<ConstantStruct>(CA->getOperand(0));
106   ASSERT_NE(nullptr, CS);
107   ConstantInt *Pri = dyn_cast<ConstantInt>(CS->getOperand(0));
108   ASSERT_NE(nullptr, Pri);
109   EXPECT_EQ(11u, Pri->getLimitedValue());
110   EXPECT_EQ(F, dyn_cast<Function>(CS->getOperand(1)));
111   EXPECT_EQ(F, CS->getOperand(2));
112 }
113 
114 TEST_P(ModuleUtilsTest, AppendToArray) {
115   LLVMContext C;
116 
117   std::unique_ptr<Module> M =
118       parseIR(C, (R"(@)" + arrayName() +
119                   R"( = appending global [2 x { i32, ptr, ptr }] [
120             { i32, ptr, ptr } { i32 65535, ptr  null, ptr null },
121             { i32, ptr, ptr } { i32 0, ptr  null, ptr null }]
122       )")
123                      .str());
124 
125   EXPECT_EQ(2, getListSize(*M, arrayName()));
126   appendFn()(
127       *M,
128       cast<Function>(
129           M->getOrInsertFunction("ctor", Type::getVoidTy(C)).getCallee()),
130       11, nullptr);
131   EXPECT_EQ(3, getListSize(*M, arrayName()));
132 }
133 
134 TEST_P(ModuleUtilsTest, UpdateArray) {
135   LLVMContext C;
136 
137   std::unique_ptr<Module> M =
138       parseIR(C, (R"(@)" + arrayName() +
139                   R"( = appending global [2 x { i32, ptr, ptr }] [
140             { i32, ptr, ptr } { i32 65535, ptr  null, ptr null },
141             { i32, ptr, ptr } { i32 0, ptr  null, ptr null }]
142       )")
143                      .str());
144 
145   EXPECT_EQ(2, getListSize(*M, arrayName()));
146   transformFn()(*M, [](Constant *C) -> Constant * {
147     ConstantStruct *CS = dyn_cast<ConstantStruct>(C);
148     if (!CS)
149       return nullptr;
150     StructType *EltTy = cast<StructType>(C->getType());
151     Constant *CSVals[3] = {
152         ConstantInt::getSigned(CS->getOperand(0)->getType(), 12),
153         CS->getOperand(1),
154         CS->getOperand(2),
155     };
156     return ConstantStruct::get(EltTy,
157                                ArrayRef(CSVals, EltTy->getNumElements()));
158   });
159   EXPECT_EQ(1, getListSize(*M, arrayName()));
160   ConstantArray *CA = dyn_cast<ConstantArray>(
161       M->getGlobalVariable(arrayName())->getInitializer());
162   ASSERT_NE(nullptr, CA);
163   ConstantStruct *CS = dyn_cast<ConstantStruct>(CA->getOperand(0));
164   ASSERT_NE(nullptr, CS);
165   ConstantInt *Pri = dyn_cast<ConstantInt>(CS->getOperand(0));
166   ASSERT_NE(nullptr, Pri);
167   EXPECT_EQ(12u, Pri->getLimitedValue());
168 }
169