1 //===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "DirectXIRPasses/PointerTypeAnalysis.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Instructions.h"
12 #include "llvm/IR/LLVMContext.h"
13 #include "llvm/IR/Module.h"
14 #include "llvm/IR/Type.h"
15 #include "llvm/IR/TypedPointerType.h"
16 #include "llvm/Support/SourceMgr.h"
17
18 #include "gmock/gmock.h"
19 #include "gtest/gtest.h"
20
21 using ::testing::Contains;
22 using ::testing::Pair;
23
24 using namespace llvm;
25 using namespace llvm::dxil;
26
27 template <typename T> struct IsA {
operator ==(const Value * V,const IsA &)28 friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
29 };
30
// %p is stored through as i32 and loaded through as i64, so there is no single
// consistent pointee type. The test expects the argument to be classified as
// i8* in that case.
TEST(PointerTypeAnalysis, DigressToi8) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      store i32 0, ptr %p
      %v = load i64, ptr %p
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's own diagnostic so a broken fixture is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? " << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // One entry for the function, one for its argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
}
54
// The only use of %p is a store of an i32, so the test expects the argument
// (and hence the function's parameter) to be typed as i32*.
TEST(PointerTypeAnalysis, DiscoverStore) {
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      store i32 0, ptr %p
      ret i32 0
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's own diagnostic so a broken fixture is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? " << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // One entry for the function, one for its argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
77
// The only use of %p is an i32 load, so the test expects the argument (and
// hence the function's parameter) to be typed as i32*.
TEST(PointerTypeAnalysis, DiscoverLoad) {
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      %v = load i32, ptr %p
      ret i32 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's own diagnostic so a broken fixture is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? " << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // One entry for the function, one for its argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
100
// %p is indexed by an i64 GEP. The test expects both the argument and the GEP
// result to be typed as i64*, and the function to return i64*.
TEST(PointerTypeAnalysis, DiscoverGEP) {
  StringRef Assembly = R"(
    define ptr @test(ptr %p) {
      %p2 = getelementptr i64, ptr %p, i64 1
      ret ptr %p2
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's own diagnostic so a broken fixture is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? " << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Function, argument, and the GEP result.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
}
125
// %p is loaded as a pointer, and that pointer is then loaded as i64. The test
// expects the analysis to follow the indirection: %p becomes i64** and the
// intermediate loaded pointer becomes i64*.
TEST(PointerTypeAnalysis, TraceIndirect) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %p2 = load ptr, ptr %p
      %v = load i64, ptr %p2
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's own diagnostic so a broken fixture is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? " << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Function, argument, and the pointer-typed load %p2.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
}
152
// %p reaches its uses only through two ptr-to-ptr bitcasts. One cast is stored
// through as i32 and the other is loaded through as i64. The test expects each
// cast to take the type of its own use, and the argument itself to be i8*.
TEST(PointerTypeAnalysis, WithNoOpCasts) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %1 = bitcast ptr %p to ptr
      %2 = bitcast ptr %p to ptr
      store i32 0, ptr %1, align 4
      %3 = load i64, ptr %2, align 8
      ret i64 %3
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's own diagnostic so a broken fixture is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? " << Error.getMessage().str();

  // Grab the two casts by position (%1 then %2) so each one can be checked
  // against the type of its own use. Matching only on IsA<BitCastInst>()
  // cannot tell the store-side cast from the load-side cast, so a swapped
  // classification would go unnoticed.
  const Function *F = M->getFunction("test");
  ASSERT_NE(F, nullptr);
  auto InstIt = F->getEntryBlock().begin();
  const Value *StoreCast = &*InstIt++;
  const Value *LoadCast = &*InstIt;
  ASSERT_TRUE(isa<BitCastInst>(StoreCast));
  ASSERT_TRUE(isa<BitCastInst>(LoadCast));

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Function, argument, and the two bitcasts.
  ASSERT_EQ(Map.size(), 4u);

  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
  // %1 feeds the i32 store; %2 feeds the i64 load.
  EXPECT_THAT(Map, Contains(Pair(StoreCast, I32Ptr)));
  EXPECT_THAT(Map, Contains(Pair(LoadCast, I64Ptr)));
}
183