xref: /llvm-project/clang/lib/AST/ByteCode/Context.cpp (revision 83fea8b809b284594e6dd133150bb6d365775e5b)
1 //===--- Context.cpp - Context for the constexpr VM -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Context.h"
10 #include "ByteCodeEmitter.h"
11 #include "Compiler.h"
12 #include "EvalEmitter.h"
13 #include "Interp.h"
14 #include "InterpFrame.h"
15 #include "InterpStack.h"
16 #include "PrimType.h"
17 #include "Program.h"
18 #include "clang/AST/Expr.h"
19 #include "clang/Basic/TargetInfo.h"
20 
21 using namespace clang;
22 using namespace clang::interp;
23 
24 Context::Context(ASTContext &Ctx) : Ctx(Ctx), P(new Program(*this)) {}
25 
26 Context::~Context() {}
27 
28 bool Context::isPotentialConstantExpr(State &Parent, const FunctionDecl *FD) {
29   assert(Stk.empty());
30   Function *Func = P->getFunction(FD);
31   if (!Func || !Func->hasBody())
32     Func = Compiler<ByteCodeEmitter>(*this, *P).compileFunc(FD);
33 
34   if (!Func)
35     return false;
36 
37   APValue DummyResult;
38   if (!Run(Parent, Func, DummyResult))
39     return false;
40 
41   return Func->isConstexpr();
42 }
43 
44 bool Context::evaluateAsRValue(State &Parent, const Expr *E, APValue &Result) {
45   ++EvalID;
46   bool Recursing = !Stk.empty();
47   size_t StackSizeBefore = Stk.size();
48   Compiler<EvalEmitter> C(*this, *P, Parent, Stk);
49 
50   auto Res = C.interpretExpr(E, /*ConvertResultToRValue=*/E->isGLValue());
51 
52   if (Res.isInvalid()) {
53     C.cleanup();
54     Stk.clearTo(StackSizeBefore);
55     return false;
56   }
57 
58   if (!Recursing) {
59     assert(Stk.empty());
60     C.cleanup();
61 #ifndef NDEBUG
62     // Make sure we don't rely on some value being still alive in
63     // InterpStack memory.
64     Stk.clearTo(StackSizeBefore);
65 #endif
66   }
67 
68   Result = Res.toAPValue();
69 
70   return true;
71 }
72 
73 bool Context::evaluate(State &Parent, const Expr *E, APValue &Result,
74                        ConstantExprKind Kind) {
75   ++EvalID;
76   bool Recursing = !Stk.empty();
77   size_t StackSizeBefore = Stk.size();
78   Compiler<EvalEmitter> C(*this, *P, Parent, Stk);
79 
80   auto Res = C.interpretExpr(E, /*ConvertResultToRValue=*/false,
81                              /*DestroyToplevelScope=*/Kind ==
82                                  ConstantExprKind::ClassTemplateArgument);
83   if (Res.isInvalid()) {
84     C.cleanup();
85     Stk.clearTo(StackSizeBefore);
86     return false;
87   }
88 
89   if (!Recursing) {
90     assert(Stk.empty());
91     C.cleanup();
92 #ifndef NDEBUG
93     // Make sure we don't rely on some value being still alive in
94     // InterpStack memory.
95     Stk.clearTo(StackSizeBefore);
96 #endif
97   }
98 
99   Result = Res.toAPValue();
100   return true;
101 }
102 
103 bool Context::evaluateAsInitializer(State &Parent, const VarDecl *VD,
104                                     APValue &Result) {
105   ++EvalID;
106   bool Recursing = !Stk.empty();
107   size_t StackSizeBefore = Stk.size();
108   Compiler<EvalEmitter> C(*this, *P, Parent, Stk);
109 
110   bool CheckGlobalInitialized =
111       shouldBeGloballyIndexed(VD) &&
112       (VD->getType()->isRecordType() || VD->getType()->isArrayType());
113   auto Res = C.interpretDecl(VD, CheckGlobalInitialized);
114   if (Res.isInvalid()) {
115     C.cleanup();
116     Stk.clearTo(StackSizeBefore);
117 
118     return false;
119   }
120 
121   if (!Recursing) {
122     assert(Stk.empty());
123     C.cleanup();
124 #ifndef NDEBUG
125     // Make sure we don't rely on some value being still alive in
126     // InterpStack memory.
127     Stk.clearTo(StackSizeBefore);
128 #endif
129   }
130 
131   Result = Res.toAPValue();
132   return true;
133 }
134 
135 const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); }
136 
137 std::optional<PrimType> Context::classify(QualType T) const {
138   if (T->isBooleanType())
139     return PT_Bool;
140 
141   // We map these to primitive arrays.
142   if (T->isAnyComplexType() || T->isVectorType())
143     return std::nullopt;
144 
145   if (T->isSignedIntegerOrEnumerationType()) {
146     switch (Ctx.getIntWidth(T)) {
147     case 64:
148       return PT_Sint64;
149     case 32:
150       return PT_Sint32;
151     case 16:
152       return PT_Sint16;
153     case 8:
154       return PT_Sint8;
155     default:
156       return PT_IntAPS;
157     }
158   }
159 
160   if (T->isUnsignedIntegerOrEnumerationType()) {
161     switch (Ctx.getIntWidth(T)) {
162     case 64:
163       return PT_Uint64;
164     case 32:
165       return PT_Uint32;
166     case 16:
167       return PT_Uint16;
168     case 8:
169       return PT_Uint8;
170     case 1:
171       // Might happen for enum types.
172       return PT_Bool;
173     default:
174       return PT_IntAP;
175     }
176   }
177 
178   if (T->isNullPtrType())
179     return PT_Ptr;
180 
181   if (T->isFloatingType())
182     return PT_Float;
183 
184   if (T->isSpecificBuiltinType(BuiltinType::BoundMember) ||
185       T->isMemberPointerType())
186     return PT_MemberPtr;
187 
188   if (T->isFunctionPointerType() || T->isFunctionReferenceType() ||
189       T->isFunctionType() || T->isBlockPointerType())
190     return PT_FnPtr;
191 
192   if (T->isPointerOrReferenceType() || T->isObjCObjectPointerType())
193     return PT_Ptr;
194 
195   if (const auto *AT = T->getAs<AtomicType>())
196     return classify(AT->getValueType());
197 
198   if (const auto *DT = dyn_cast<DecltypeType>(T))
199     return classify(DT->getUnderlyingType());
200 
201   return std::nullopt;
202 }
203 
204 unsigned Context::getCharBit() const {
205   return Ctx.getTargetInfo().getCharWidth();
206 }
207 
208 /// Simple wrapper around getFloatTypeSemantics() to make code a
209 /// little shorter.
210 const llvm::fltSemantics &Context::getFloatSemantics(QualType T) const {
211   return Ctx.getFloatTypeSemantics(T);
212 }
213 
214 bool Context::Run(State &Parent, const Function *Func, APValue &Result) {
215 
216   {
217     InterpState State(Parent, *P, Stk, *this);
218     State.Current = new InterpFrame(State, Func, /*Caller=*/nullptr, CodePtr(),
219                                     Func->getArgSize());
220     if (Interpret(State, Result)) {
221       assert(Stk.empty());
222       return true;
223     }
224 
225     // State gets destroyed here, so the Stk.clear() below doesn't accidentally
226     // remove values the State's destructor might access.
227   }
228 
229   Stk.clear();
230   return false;
231 }
232 
233 // TODO: Virtual bases?
234 const CXXMethodDecl *
235 Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
236                                const CXXRecordDecl *StaticDecl,
237                                const CXXMethodDecl *InitialFunction) const {
238   assert(DynamicDecl);
239   assert(StaticDecl);
240   assert(InitialFunction);
241 
242   const CXXRecordDecl *CurRecord = DynamicDecl;
243   const CXXMethodDecl *FoundFunction = InitialFunction;
244   for (;;) {
245     const CXXMethodDecl *Overrider =
246         FoundFunction->getCorrespondingMethodDeclaredInClass(CurRecord, false);
247     if (Overrider)
248       return Overrider;
249 
250     // Common case of only one base class.
251     if (CurRecord->getNumBases() == 1) {
252       CurRecord = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl();
253       continue;
254     }
255 
256     // Otherwise, go to the base class that will lead to the StaticDecl.
257     for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
258       const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
259       if (Base == StaticDecl || Base->isDerivedFrom(StaticDecl)) {
260         CurRecord = Base;
261         break;
262       }
263     }
264   }
265 
266   llvm_unreachable(
267       "Couldn't find an overriding function in the class hierarchy?");
268   return nullptr;
269 }
270 
271 const Function *Context::getOrCreateFunction(const FunctionDecl *FD) {
272   assert(FD);
273   const Function *Func = P->getFunction(FD);
274   bool IsBeingCompiled = Func && Func->isDefined() && !Func->isFullyCompiled();
275   bool WasNotDefined = Func && !Func->isConstexpr() && !Func->isDefined();
276 
277   if (IsBeingCompiled)
278     return Func;
279 
280   if (!Func || WasNotDefined) {
281     if (auto F = Compiler<ByteCodeEmitter>(*this, *P).compileFunc(FD))
282       Func = F;
283   }
284 
285   return Func;
286 }
287 
288 unsigned Context::collectBaseOffset(const RecordDecl *BaseDecl,
289                                     const RecordDecl *DerivedDecl) const {
290   assert(BaseDecl);
291   assert(DerivedDecl);
292   const auto *FinalDecl = cast<CXXRecordDecl>(BaseDecl);
293   const RecordDecl *CurDecl = DerivedDecl;
294   const Record *CurRecord = P->getOrCreateRecord(CurDecl);
295   assert(CurDecl && FinalDecl);
296 
297   unsigned OffsetSum = 0;
298   for (;;) {
299     assert(CurRecord->getNumBases() > 0);
300     // One level up
301     for (const Record::Base &B : CurRecord->bases()) {
302       const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);
303 
304       if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
305         OffsetSum += B.Offset;
306         CurRecord = B.R;
307         CurDecl = BaseDecl;
308         break;
309       }
310     }
311     if (CurDecl == FinalDecl)
312       break;
313   }
314 
315   assert(OffsetSum > 0);
316   return OffsetSum;
317 }
318 
319 const Record *Context::getRecord(const RecordDecl *D) const {
320   return P->getOrCreateRecord(D);
321 }
322