//===--- Context.cpp - Context for the constexpr VM -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Context.h"
#include "ByteCodeEmitter.h"
#include "Compiler.h"
#include "EvalEmitter.h"
#include "Interp.h"
#include "InterpFrame.h"
#include "InterpStack.h"
#include "PrimType.h"
#include "Program.h"
#include "clang/AST/Expr.h"
#include "clang/Basic/TargetInfo.h"

using namespace clang;
using namespace clang::interp;

Context::Context(ASTContext &Ctx) : Ctx(Ctx), P(new Program(*this)) {}

Context::~Context() {}

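/// Checks whether FD can potentially be evaluated as a constant expression.
/// The function is compiled to bytecode (if it hasn't been already) and
/// interpreted once; diagnostics are reported through Parent.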
bool Context::isPotentialConstantExpr(State &Parent, const FunctionDecl *FD) {
  assert(Stk.empty());
  const Function *Func = getOrCreateFunction(FD);
  if (!Func)
    return false;

  if (!Run(Parent, Func))
    return false;

  return Func->isConstexpr();
}

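/// Evaluates E and writes the result to Result. If E is a glvalue, the
/// result is converted to an rvalue.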
bool Context::evaluateAsRValue(State &Parent, const Expr *E, APValue &Result) {
  ++EvalID;
  bool Recursing = !Stk.empty();
  size_t StackSizeBefore = Stk.size();
  Compiler<EvalEmitter> C(*this, *P, Parent, Stk);

  auto Res = C.interpretExpr(E, /*ConvertResultToRValue=*/E->isGLValue());

  if (Res.isInvalid()) {
    C.cleanup();
    Stk.clearTo(StackSizeBefore);
    return false;
  }

  if (!Recursing) {
    assert(Stk.empty());
    C.cleanup();
#ifndef NDEBUG
    // Make sure we don't rely on some value still being alive in
    // InterpStack memory.
    Stk.clearTo(StackSizeBefore);
#endif
  }

  Result = Res.toAPValue();

  return true;
}

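/// Evaluates E as a constant expression. Unlike evaluateAsRValue(), the
/// result is not converted to an rvalue, and the top-level scope is
/// destroyed after evaluation.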
bool Context::evaluate(State &Parent, const Expr *E, APValue &Result,
                       ConstantExprKind Kind) {
  ++EvalID;
  bool Recursing = !Stk.empty();
  size_t StackSizeBefore = Stk.size();
  Compiler<EvalEmitter> C(*this, *P, Parent, Stk);

  auto Res = C.interpretExpr(E, /*ConvertResultToRValue=*/false,
                             /*DestroyToplevelScope=*/true);
  if (Res.isInvalid()) {
    C.cleanup();
    Stk.clearTo(StackSizeBefore);
    return false;
  }

  if (!Recursing) {
    assert(Stk.empty());
    C.cleanup();
#ifndef NDEBUG
    // Make sure we don't rely on some value still being alive in
    // InterpStack memory.
    Stk.clearTo(StackSizeBefore);
#endif
  }

  Result = Res.toAPValue();
  return true;
}

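/// Evaluates the initializer of VD and writes it to Result. For globally
/// indexed record and array variables, additionally checks that the global
/// ends up fully initialized.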
bool Context::evaluateAsInitializer(State &Parent, const VarDecl *VD,
                                    APValue &Result) {
  ++EvalID;
  bool Recursing = !Stk.empty();
  size_t StackSizeBefore = Stk.size();
  Compiler<EvalEmitter> C(*this, *P, Parent, Stk);

  bool CheckGlobalInitialized =
      shouldBeGloballyIndexed(VD) &&
      (VD->getType()->isRecordType() || VD->getType()->isArrayType());
  auto Res = C.interpretDecl(VD, CheckGlobalInitialized);
  if (Res.isInvalid()) {
    C.cleanup();
    Stk.clearTo(StackSizeBefore);

    return false;
  }

  if (!Recursing) {
    assert(Stk.empty());
    C.cleanup();
#ifndef NDEBUG
    // Make sure we don't rely on some value still being alive in
    // InterpStack memory.
    Stk.clearTo(StackSizeBefore);
#endif
  }

  Result = Res.toAPValue();
  return true;
}

const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); }

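/// Classifies T as one of the interpreter's primitive types. Returns
/// std::nullopt for types that aren't backed by a single primitive value,
/// such as complex and vector types, which are mapped to primitive arrays.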
std::optional<PrimType> Context::classify(QualType T) const {
  if (T->isBooleanType())
    return PT_Bool;

  // We map these to primitive arrays.
  if (T->isAnyComplexType() || T->isVectorType())
    return std::nullopt;

  if (T->isSignedIntegerOrEnumerationType()) {
    switch (Ctx.getIntWidth(T)) {
    case 64:
      return PT_Sint64;
    case 32:
      return PT_Sint32;
    case 16:
      return PT_Sint16;
    case 8:
      return PT_Sint8;
    default:
      return PT_IntAPS;
    }
  }

  if (T->isUnsignedIntegerOrEnumerationType()) {
    switch (Ctx.getIntWidth(T)) {
    case 64:
      return PT_Uint64;
    case 32:
      return PT_Uint32;
    case 16:
      return PT_Uint16;
    case 8:
      return PT_Uint8;
    case 1:
      // Might happen for enum types.
      return PT_Bool;
    default:
      return PT_IntAP;
    }
  }

  if (T->isNullPtrType())
    return PT_Ptr;

  if (T->isFloatingType())
    return PT_Float;

  if (T->isSpecificBuiltinType(BuiltinType::BoundMember) ||
      T->isMemberPointerType())
    return PT_MemberPtr;

  if (T->isFunctionPointerType() || T->isFunctionReferenceType() ||
      T->isFunctionType() || T->isBlockPointerType())
    return PT_FnPtr;

  if (T->isPointerOrReferenceType() || T->isObjCObjectPointerType())
    return PT_Ptr;

  if (const auto *AT = T->getAs<AtomicType>())
    return classify(AT->getValueType());

  if (const auto *DT = dyn_cast<DecltypeType>(T))
    return classify(DT->getUnderlyingType());

  if (T->isFixedPointType())
    return PT_FixedPoint;

  return std::nullopt;
}

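/// Returns the width of a char, in bits, for the current target.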
unsigned Context::getCharBit() const {
  return Ctx.getTargetInfo().getCharWidth();
}

/// Simple wrapper around getFloatTypeSemantics() to make code a
/// little shorter.
const llvm::fltSemantics &Context::getFloatSemantics(QualType T) const {
  return Ctx.getFloatTypeSemantics(T);
}

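/// Interprets Func on a new frame with no caller. On failure, the stack is
/// only cleared after the InterpState has been destroyed, so the destructor
/// can still access the values it needs.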
bool Context::Run(State &Parent, const Function *Func) {

  {
    InterpState State(Parent, *P, Stk, *this);
    State.Current = new InterpFrame(State, Func, /*Caller=*/nullptr, CodePtr(),
                                    Func->getArgSize());
    if (Interpret(State)) {
      assert(Stk.empty());
      return true;
    }

    // State gets destroyed here, so the Stk.clear() below doesn't accidentally
    // remove values the State's destructor might access.
  }

  Stk.clear();
  return false;
}

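/// Finds the overrider of InitialFunction for an object of dynamic type
/// DynamicDecl, walking the bases from DynamicDecl towards StaticDecl.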
// TODO: Virtual bases?
const CXXMethodDecl *
Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
                               const CXXRecordDecl *StaticDecl,
                               const CXXMethodDecl *InitialFunction) const {
  assert(DynamicDecl);
  assert(StaticDecl);
  assert(InitialFunction);

  const CXXRecordDecl *CurRecord = DynamicDecl;
  const CXXMethodDecl *FoundFunction = InitialFunction;
  for (;;) {
    const CXXMethodDecl *Overrider =
        FoundFunction->getCorrespondingMethodDeclaredInClass(CurRecord, false);
    if (Overrider)
      return Overrider;

    // Common case of only one base class.
    if (CurRecord->getNumBases() == 1) {
      CurRecord = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl();
      continue;
    }

    // Otherwise, go to the base class that will lead to the StaticDecl.
    for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
      const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
      if (Base == StaticDecl || Base->isDerivedFrom(StaticDecl)) {
        CurRecord = Base;
        break;
      }
    }
  }

  llvm_unreachable(
      "Couldn't find an overriding function in the class hierarchy?");
}

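/// Returns the bytecode Function for FD, compiling it first if it hasn't
/// been compiled yet or was previously registered without a definition.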
const Function *Context::getOrCreateFunction(const FunctionDecl *FD) {
  assert(FD);
  FD = FD->getMostRecentDecl();
  const Function *Func = P->getFunction(FD);
  bool IsBeingCompiled = Func && Func->isDefined() && !Func->isFullyCompiled();
  bool WasNotDefined = Func && !Func->isConstexpr() && !Func->isDefined();

  if (IsBeingCompiled)
    return Func;

  if (!Func || WasNotDefined) {
    if (auto F = Compiler<ByteCodeEmitter>(*this, *P).compileFunc(FD))
      Func = F;
  }

  return Func;
}

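/// Returns the offset of BaseDecl within DerivedDecl, accumulated over the
/// base class path in the interpreter's Record layout.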
unsigned Context::collectBaseOffset(const RecordDecl *BaseDecl,
                                    const RecordDecl *DerivedDecl) const {
  assert(BaseDecl);
  assert(DerivedDecl);
  const auto *FinalDecl = cast<CXXRecordDecl>(BaseDecl);
  const RecordDecl *CurDecl = DerivedDecl;
  const Record *CurRecord = P->getOrCreateRecord(CurDecl);
  assert(CurDecl && FinalDecl);

  unsigned OffsetSum = 0;
  for (;;) {
    assert(CurRecord->getNumBases() > 0);
    // One level up
    for (const Record::Base &B : CurRecord->bases()) {
      const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);

      if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
        OffsetSum += B.Offset;
        CurRecord = B.R;
        CurDecl = BaseDecl;
        break;
      }
    }
    if (CurDecl == FinalDecl)
      break;
  }

  assert(OffsetSum > 0);
  return OffsetSum;
}

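/// Returns the interpreter's Record for D, creating it if necessary.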
const Record *Context::getRecord(const RecordDecl *D) const {
  return P->getOrCreateRecord(D);
}
321