xref: /llvm-project/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp (revision 7d17114c6b5f83ad3a58d2fb14068ea43738443c)
//===- Target/DirectX/PointerTypeAnalysis.cpp - PointerType analysis ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Analysis pass to assign types to opaque pointers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PointerTypeAnalysis.h"
14 #include "llvm/IR/Constants.h"
15 #include "llvm/IR/GlobalVariable.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/Module.h"
18 
19 using namespace llvm;
20 using namespace llvm::dxil;
21 
22 namespace {
23 
24 // Classifies the type of the value passed in by walking the value's users to
25 // find a typed instruction to materialize a type from.
classifyPointerType(const Value * V,PointerTypeMap & Map)26 Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
27   assert(V->getType()->isPointerTy() &&
28          "classifyPointerType called with non-pointer");
29   auto It = Map.find(V);
30   if (It != Map.end())
31     return It->second;
32 
33   Type *PointeeTy = nullptr;
34   if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
35     if (!Inst->getResultElementType()->isPointerTy())
36       PointeeTy = Inst->getResultElementType();
37   } else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
38     PointeeTy = Inst->getAllocatedType();
39   } else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
40     PointeeTy = GV->getValueType();
41   }
42 
43   for (const auto *User : V->users()) {
44     Type *NewPointeeTy = nullptr;
45     if (const auto *Inst = dyn_cast<LoadInst>(User)) {
46       NewPointeeTy = Inst->getType();
47     } else if (const auto *Inst = dyn_cast<StoreInst>(User)) {
48       NewPointeeTy = Inst->getValueOperand()->getType();
49       // When store value is ptr type, cannot get more type info.
50       if (NewPointeeTy->isPointerTy())
51         continue;
52     } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
53       NewPointeeTy = Inst->getSourceElementType();
54     }
55     if (NewPointeeTy) {
56       // HLSL doesn't support pointers, so it is unlikely to get more than one
57       // or two levels of indirection in the IR. Because of this, recursion is
58       // pretty safe.
59       if (NewPointeeTy->isPointerTy()) {
60         PointeeTy = classifyPointerType(User, Map);
61         break;
62       }
63       if (!PointeeTy)
64         PointeeTy = NewPointeeTy;
65       else if (PointeeTy != NewPointeeTy)
66         PointeeTy = Type::getInt8Ty(V->getContext());
67     }
68   }
69   // If we were unable to determine the pointee type, set to i8
70   if (!PointeeTy)
71     PointeeTy = Type::getInt8Ty(V->getContext());
72   auto *TypedPtrTy =
73       TypedPointerType::get(PointeeTy, V->getType()->getPointerAddressSpace());
74 
75   Map[V] = TypedPtrTy;
76   return TypedPtrTy;
77 }
78 
79 // This function constructs a function type accepting typed pointers. It only
80 // handles function arguments and return types, and assigns the function type to
81 // the function's value in the type map.
classifyFunctionType(const Function & F,PointerTypeMap & Map)82 Type *classifyFunctionType(const Function &F, PointerTypeMap &Map) {
83   auto It = Map.find(&F);
84   if (It != Map.end())
85     return It->second;
86 
87   SmallVector<Type *, 8> NewArgs;
88   Type *RetTy = F.getReturnType();
89   LLVMContext &Ctx = F.getContext();
90   if (RetTy->isPointerTy()) {
91     RetTy = nullptr;
92     for (const auto &B : F) {
93       const auto *RetInst = dyn_cast_or_null<ReturnInst>(B.getTerminator());
94       if (!RetInst)
95         continue;
96 
97       Type *NewRetTy = classifyPointerType(RetInst->getReturnValue(), Map);
98       if (!RetTy)
99         RetTy = NewRetTy;
100       else if (RetTy != NewRetTy)
101         RetTy = TypedPointerType::get(
102             Type::getInt8Ty(Ctx), F.getReturnType()->getPointerAddressSpace());
103     }
104     // For function decl.
105     if (!RetTy)
106       RetTy = TypedPointerType::get(
107           Type::getInt8Ty(Ctx), F.getReturnType()->getPointerAddressSpace());
108   }
109   for (auto &A : F.args()) {
110     Type *ArgTy = A.getType();
111     if (ArgTy->isPointerTy())
112       ArgTy = classifyPointerType(&A, Map);
113     NewArgs.push_back(ArgTy);
114   }
115   auto *TypedPtrTy =
116       TypedPointerType::get(FunctionType::get(RetTy, NewArgs, false), 0);
117   Map[&F] = TypedPtrTy;
118   return TypedPtrTy;
119 }
120 } // anonymous namespace
121 
classifyConstantWithOpaquePtr(const Constant * C,PointerTypeMap & Map)122 static Type *classifyConstantWithOpaquePtr(const Constant *C,
123                                            PointerTypeMap &Map) {
124   // FIXME: support ConstantPointerNull which could map to more than one
125   // TypedPointerType.
126   // See https://github.com/llvm/llvm-project/issues/57942.
127   if (isa<ConstantPointerNull>(C))
128     return TypedPointerType::get(Type::getInt8Ty(C->getContext()),
129                                  C->getType()->getPointerAddressSpace());
130 
131   // Skip ConstantData which cannot have opaque ptr.
132   if (isa<ConstantData>(C))
133     return C->getType();
134 
135   auto It = Map.find(C);
136   if (It != Map.end())
137     return It->second;
138 
139   if (const auto *F = dyn_cast<Function>(C))
140     return classifyFunctionType(*F, Map);
141 
142   Type *Ty = C->getType();
143   Type *TargetTy = nullptr;
144   if (auto *CS = dyn_cast<ConstantStruct>(C)) {
145     SmallVector<Type *> EltTys;
146     for (unsigned int I = 0; I < CS->getNumOperands(); ++I) {
147       const Constant *Elt = C->getAggregateElement(I);
148       Type *EltTy = classifyConstantWithOpaquePtr(Elt, Map);
149       EltTys.emplace_back(EltTy);
150     }
151     TargetTy = StructType::get(C->getContext(), EltTys);
152   } else if (auto *CA = dyn_cast<ConstantAggregate>(C)) {
153 
154     Type *TargetEltTy = nullptr;
155     for (auto &Elt : CA->operands()) {
156       Type *EltTy = classifyConstantWithOpaquePtr(cast<Constant>(&Elt), Map);
157       assert(TargetEltTy == EltTy || TargetEltTy == nullptr);
158       TargetEltTy = EltTy;
159     }
160 
161     if (auto *AT = dyn_cast<ArrayType>(Ty)) {
162       TargetTy = ArrayType::get(TargetEltTy, AT->getNumElements());
163     } else {
164       // Not struct, not array, must be vector here.
165       auto *VT = cast<VectorType>(Ty);
166       TargetTy = VectorType::get(TargetEltTy, VT);
167     }
168   }
169   // Must have a target ty when map.
170   assert(TargetTy && "PointerTypeAnalyisis failed to identify target type");
171 
172   // Same type, no need to map.
173   if (TargetTy == Ty)
174     return Ty;
175 
176   Map[C] = TargetTy;
177   return TargetTy;
178 }
179 
classifyGlobalCtorPointerType(const GlobalVariable & GV,PointerTypeMap & Map)180 static void classifyGlobalCtorPointerType(const GlobalVariable &GV,
181                                           PointerTypeMap &Map) {
182   const auto *CA = cast<ConstantArray>(GV.getInitializer());
183   // Type for global ctor should be array of { i32, void ()*, i8* }.
184   Type *CtorArrayTy = classifyConstantWithOpaquePtr(CA, Map);
185 
186   // Map the global type.
187   Map[&GV] = TypedPointerType::get(CtorArrayTy,
188                                    GV.getType()->getPointerAddressSpace());
189 }
190 
run(const Module & M)191 PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
192   PointerTypeMap Map;
193   for (auto &G : M.globals()) {
194     if (G.getType()->isPointerTy())
195       classifyPointerType(&G, Map);
196     if (G.getName() == "llvm.global_ctors")
197       classifyGlobalCtorPointerType(G, Map);
198   }
199 
200   for (auto &F : M) {
201     classifyFunctionType(F, Map);
202 
203     for (const auto &B : F) {
204       for (const auto &I : B) {
205         if (I.getType()->isPointerTy())
206           classifyPointerType(&I, Map);
207       }
208     }
209   }
210   return Map;
211 }
212