xref: /llvm-project/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp (revision 3f24561bc14fab4dbedd95955c45983197b659f3)
1 //===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DirectXIRPasses/PointerTypeAnalysis.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Instructions.h"
12 #include "llvm/IR/LLVMContext.h"
13 #include "llvm/IR/Module.h"
14 #include "llvm/IR/Type.h"
15 #include "llvm/IR/TypedPointerType.h"
16 #include "llvm/Support/SourceMgr.h"
17 
18 #include "gmock/gmock.h"
19 #include "gtest/gtest.h"
20 
21 using ::testing::Contains;
22 using ::testing::Pair;
23 
24 using namespace llvm;
25 using namespace llvm::dxil;
26 
27 template <typename T> struct IsA {
operator ==(const Value * V,const IsA &)28   friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
29 };
30 
// %p is stored through as i32 and loaded through as i64. With these
// conflicting element types for the same pointer, the test expects the
// analysis to fall back to i8* for the argument (and therefore for the
// function's parameter type).
TEST(PointerTypeAnalysis, DigressToi8) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      store i32 0, ptr %p
      %v = load i64, ptr %p
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's diagnostic so a malformed IR snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function itself and its argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
}
54 
// A single i32 store through %p: the test expects %p to be typed i32* and the
// function to be typed as a pointer to i32(i32*).
TEST(PointerTypeAnalysis, DiscoverStore) {
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      store i32 0, ptr %p
      ret i32 0
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's diagnostic so a malformed IR snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function itself and its argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
77 
// A single i32 load through %p: the test expects %p to be typed i32* and the
// function to be typed as a pointer to i32(i32*).
TEST(PointerTypeAnalysis, DiscoverLoad) {
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      %v = load i32, ptr %p
      ret i32 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's diagnostic so a malformed IR snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function itself and its argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
100 
// A GEP whose source element type is i64: the test expects the argument, the
// GEP result, and the function's return type to all be typed i64*.
TEST(PointerTypeAnalysis, DiscoverGEP) {
  StringRef Assembly = R"(
    define ptr @test(ptr %p) {
      %p2 = getelementptr i64, ptr %p, i64 1
      ret ptr %p2
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's diagnostic so a malformed IR snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function, its argument, and the GEP result.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
}
125 
// %p is loaded as a pointer (%p2), which is then loaded as i64. The test
// expects the type to be traced back through the indirection: %p is i64** and
// the pointer-valued load %p2 is i64*.
TEST(PointerTypeAnalysis, TraceIndirect) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %p2 = load ptr, ptr %p
      %v = load i64, ptr %p2
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's diagnostic so a malformed IR snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function, its argument, and one load (the
  // pointer-valued %p2); the i64 load %v is not expected to be recorded.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
}
152 
// Two no-op bitcasts of %p are used with an i32 store and an i64 load. The
// test expects each bitcast to take its element type from its own use (i32*
// and i64*), while %p itself, seen with both types, falls back to i8*.
TEST(PointerTypeAnalysis, WithNoOpCasts) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %1 = bitcast ptr %p to ptr
      %2 = bitcast ptr %p to ptr
      store i32 0, ptr %1, align 4
      %3 = load i64, ptr %2, align 8
      ret i64 %3
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Surface the parser's diagnostic so a malformed IR snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function, its argument, and both bitcasts.
  ASSERT_EQ(Map.size(), 4u);

  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map,
              Contains(Pair(IsA<Function>(), TypedPointerType::get(FnTy, 0))));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I32Ptr)));
}
183