xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (revision 7a6dacaca14b62ca4b74406814becb87a3fefac0)
1*7a6dacacSDimitry Andric //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2*7a6dacacSDimitry Andric //
3*7a6dacacSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*7a6dacacSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*7a6dacacSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*7a6dacacSDimitry Andric //
7*7a6dacacSDimitry Andric //===----------------------------------------------------------------------===//
8*7a6dacacSDimitry Andric ///
9*7a6dacacSDimitry Andric /// \file
10*7a6dacacSDimitry Andric /// This file contains the IR transform to lower external or indirect calls for
11*7a6dacacSDimitry Andric /// the ARM64EC calling convention. Such calls must go through the runtime, so
12*7a6dacacSDimitry Andric /// we can translate the calling convention for calls into the emulator.
13*7a6dacacSDimitry Andric ///
14*7a6dacacSDimitry Andric /// This subsumes Control Flow Guard handling.
15*7a6dacacSDimitry Andric ///
16*7a6dacacSDimitry Andric //===----------------------------------------------------------------------===//
17*7a6dacacSDimitry Andric 
18*7a6dacacSDimitry Andric #include "AArch64.h"
19*7a6dacacSDimitry Andric #include "llvm/ADT/SetVector.h"
20*7a6dacacSDimitry Andric #include "llvm/ADT/SmallString.h"
21*7a6dacacSDimitry Andric #include "llvm/ADT/SmallVector.h"
22*7a6dacacSDimitry Andric #include "llvm/ADT/Statistic.h"
23*7a6dacacSDimitry Andric #include "llvm/IR/CallingConv.h"
24*7a6dacacSDimitry Andric #include "llvm/IR/IRBuilder.h"
25*7a6dacacSDimitry Andric #include "llvm/IR/Instruction.h"
26*7a6dacacSDimitry Andric #include "llvm/InitializePasses.h"
27*7a6dacacSDimitry Andric #include "llvm/Pass.h"
28*7a6dacacSDimitry Andric #include "llvm/Support/CommandLine.h"
29*7a6dacacSDimitry Andric #include "llvm/TargetParser/Triple.h"
30*7a6dacacSDimitry Andric 
31*7a6dacacSDimitry Andric using namespace llvm;
32*7a6dacacSDimitry Andric 
33*7a6dacacSDimitry Andric using OperandBundleDef = OperandBundleDefT<Value *>;
34*7a6dacacSDimitry Andric 
35*7a6dacacSDimitry Andric #define DEBUG_TYPE "arm64eccalllowering"
36*7a6dacacSDimitry Andric 
37*7a6dacacSDimitry Andric STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
38*7a6dacacSDimitry Andric 
39*7a6dacacSDimitry Andric static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
40*7a6dacacSDimitry Andric                                            cl::Hidden, cl::init(true));
41*7a6dacacSDimitry Andric static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
42*7a6dacacSDimitry Andric                                     cl::init(true));
43*7a6dacacSDimitry Andric 
44*7a6dacacSDimitry Andric namespace {
45*7a6dacacSDimitry Andric 
46*7a6dacacSDimitry Andric class AArch64Arm64ECCallLowering : public ModulePass {
47*7a6dacacSDimitry Andric public:
48*7a6dacacSDimitry Andric   static char ID;
49*7a6dacacSDimitry Andric   AArch64Arm64ECCallLowering() : ModulePass(ID) {
50*7a6dacacSDimitry Andric     initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
51*7a6dacacSDimitry Andric   }
52*7a6dacacSDimitry Andric 
53*7a6dacacSDimitry Andric   Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
54*7a6dacacSDimitry Andric   Function *buildEntryThunk(Function *F);
55*7a6dacacSDimitry Andric   void lowerCall(CallBase *CB);
56*7a6dacacSDimitry Andric   Function *buildGuestExitThunk(Function *F);
57*7a6dacacSDimitry Andric   bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
58*7a6dacacSDimitry Andric   bool runOnModule(Module &M) override;
59*7a6dacacSDimitry Andric 
60*7a6dacacSDimitry Andric private:
61*7a6dacacSDimitry Andric   int cfguard_module_flag = 0;
62*7a6dacacSDimitry Andric   FunctionType *GuardFnType = nullptr;
63*7a6dacacSDimitry Andric   PointerType *GuardFnPtrType = nullptr;
64*7a6dacacSDimitry Andric   Constant *GuardFnCFGlobal = nullptr;
65*7a6dacacSDimitry Andric   Constant *GuardFnGlobal = nullptr;
66*7a6dacacSDimitry Andric   Module *M = nullptr;
67*7a6dacacSDimitry Andric 
68*7a6dacacSDimitry Andric   Type *PtrTy;
69*7a6dacacSDimitry Andric   Type *I64Ty;
70*7a6dacacSDimitry Andric   Type *VoidTy;
71*7a6dacacSDimitry Andric 
72*7a6dacacSDimitry Andric   void getThunkType(FunctionType *FT, AttributeList AttrList, bool EntryThunk,
73*7a6dacacSDimitry Andric                     raw_ostream &Out, FunctionType *&Arm64Ty,
74*7a6dacacSDimitry Andric                     FunctionType *&X64Ty);
75*7a6dacacSDimitry Andric   void getThunkRetType(FunctionType *FT, AttributeList AttrList,
76*7a6dacacSDimitry Andric                        raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
77*7a6dacacSDimitry Andric                        SmallVectorImpl<Type *> &Arm64ArgTypes,
78*7a6dacacSDimitry Andric                        SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
79*7a6dacacSDimitry Andric   void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
80*7a6dacacSDimitry Andric                         raw_ostream &Out,
81*7a6dacacSDimitry Andric                         SmallVectorImpl<Type *> &Arm64ArgTypes,
82*7a6dacacSDimitry Andric                         SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
83*7a6dacacSDimitry Andric   void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
84*7a6dacacSDimitry Andric                              uint64_t ArgSizeBytes, raw_ostream &Out,
85*7a6dacacSDimitry Andric                              Type *&Arm64Ty, Type *&X64Ty);
86*7a6dacacSDimitry Andric };
87*7a6dacacSDimitry Andric 
88*7a6dacacSDimitry Andric } // end anonymous namespace
89*7a6dacacSDimitry Andric 
90*7a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT,
91*7a6dacacSDimitry Andric                                               AttributeList AttrList,
92*7a6dacacSDimitry Andric                                               bool EntryThunk, raw_ostream &Out,
93*7a6dacacSDimitry Andric                                               FunctionType *&Arm64Ty,
94*7a6dacacSDimitry Andric                                               FunctionType *&X64Ty) {
95*7a6dacacSDimitry Andric   Out << (EntryThunk ? "$ientry_thunk$cdecl$" : "$iexit_thunk$cdecl$");
96*7a6dacacSDimitry Andric 
97*7a6dacacSDimitry Andric   Type *Arm64RetTy;
98*7a6dacacSDimitry Andric   Type *X64RetTy;
99*7a6dacacSDimitry Andric 
100*7a6dacacSDimitry Andric   SmallVector<Type *> Arm64ArgTypes;
101*7a6dacacSDimitry Andric   SmallVector<Type *> X64ArgTypes;
102*7a6dacacSDimitry Andric 
103*7a6dacacSDimitry Andric   // The first argument to a thunk is the called function, stored in x9.
104*7a6dacacSDimitry Andric   // For exit thunks, we pass the called function down to the emulator;
105*7a6dacacSDimitry Andric   // for entry thunks, we just call the Arm64 function directly.
106*7a6dacacSDimitry Andric   if (!EntryThunk)
107*7a6dacacSDimitry Andric     Arm64ArgTypes.push_back(PtrTy);
108*7a6dacacSDimitry Andric   X64ArgTypes.push_back(PtrTy);
109*7a6dacacSDimitry Andric 
110*7a6dacacSDimitry Andric   bool HasSretPtr = false;
111*7a6dacacSDimitry Andric   getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
112*7a6dacacSDimitry Andric                   X64ArgTypes, HasSretPtr);
113*7a6dacacSDimitry Andric 
114*7a6dacacSDimitry Andric   getThunkArgTypes(FT, AttrList, Out, Arm64ArgTypes, X64ArgTypes, HasSretPtr);
115*7a6dacacSDimitry Andric 
116*7a6dacacSDimitry Andric   Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
117*7a6dacacSDimitry Andric   X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
118*7a6dacacSDimitry Andric }
119*7a6dacacSDimitry Andric 
120*7a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkArgTypes(
121*7a6dacacSDimitry Andric     FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
122*7a6dacacSDimitry Andric     SmallVectorImpl<Type *> &Arm64ArgTypes,
123*7a6dacacSDimitry Andric     SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
124*7a6dacacSDimitry Andric 
125*7a6dacacSDimitry Andric   Out << "$";
126*7a6dacacSDimitry Andric   if (FT->isVarArg()) {
127*7a6dacacSDimitry Andric     // We treat the variadic function's thunk as a normal function
128*7a6dacacSDimitry Andric     // with the following type on the ARM side:
129*7a6dacacSDimitry Andric     //   rettype exitthunk(
130*7a6dacacSDimitry Andric     //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
131*7a6dacacSDimitry Andric     //
132*7a6dacacSDimitry Andric     // that can coverage all types of variadic function.
133*7a6dacacSDimitry Andric     // x9 is similar to normal exit thunk, store the called function.
134*7a6dacacSDimitry Andric     // x0-x3 is the arguments be stored in registers.
135*7a6dacacSDimitry Andric     // x4 is the address of the arguments on the stack.
136*7a6dacacSDimitry Andric     // x5 is the size of the arguments on the stack.
137*7a6dacacSDimitry Andric     //
138*7a6dacacSDimitry Andric     // On the x64 side, it's the same except that x5 isn't set.
139*7a6dacacSDimitry Andric     //
140*7a6dacacSDimitry Andric     // If both the ARM and X64 sides are sret, there are only three
141*7a6dacacSDimitry Andric     // arguments in registers.
142*7a6dacacSDimitry Andric     //
143*7a6dacacSDimitry Andric     // If the X64 side is sret, but the ARM side isn't, we pass an extra value
144*7a6dacacSDimitry Andric     // to/from the X64 side, and let SelectionDAG transform it into a memory
145*7a6dacacSDimitry Andric     // location.
146*7a6dacacSDimitry Andric     Out << "varargs";
147*7a6dacacSDimitry Andric 
148*7a6dacacSDimitry Andric     // x0-x3
149*7a6dacacSDimitry Andric     for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
150*7a6dacacSDimitry Andric       Arm64ArgTypes.push_back(I64Ty);
151*7a6dacacSDimitry Andric       X64ArgTypes.push_back(I64Ty);
152*7a6dacacSDimitry Andric     }
153*7a6dacacSDimitry Andric 
154*7a6dacacSDimitry Andric     // x4
155*7a6dacacSDimitry Andric     Arm64ArgTypes.push_back(PtrTy);
156*7a6dacacSDimitry Andric     X64ArgTypes.push_back(PtrTy);
157*7a6dacacSDimitry Andric     // x5
158*7a6dacacSDimitry Andric     Arm64ArgTypes.push_back(I64Ty);
159*7a6dacacSDimitry Andric     // FIXME: x5 isn't actually passed/used by the x64 side; revisit once we
160*7a6dacacSDimitry Andric     // have proper isel for varargs
161*7a6dacacSDimitry Andric     X64ArgTypes.push_back(I64Ty);
162*7a6dacacSDimitry Andric     return;
163*7a6dacacSDimitry Andric   }
164*7a6dacacSDimitry Andric 
165*7a6dacacSDimitry Andric   unsigned I = 0;
166*7a6dacacSDimitry Andric   if (HasSretPtr)
167*7a6dacacSDimitry Andric     I++;
168*7a6dacacSDimitry Andric 
169*7a6dacacSDimitry Andric   if (I == FT->getNumParams()) {
170*7a6dacacSDimitry Andric     Out << "v";
171*7a6dacacSDimitry Andric     return;
172*7a6dacacSDimitry Andric   }
173*7a6dacacSDimitry Andric 
174*7a6dacacSDimitry Andric   for (unsigned E = FT->getNumParams(); I != E; ++I) {
175*7a6dacacSDimitry Andric     Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
176*7a6dacacSDimitry Andric #if 0
177*7a6dacacSDimitry Andric     // FIXME: Need more information about argument size; see
178*7a6dacacSDimitry Andric     // https://reviews.llvm.org/D132926
179*7a6dacacSDimitry Andric     uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
180*7a6dacacSDimitry Andric #else
181*7a6dacacSDimitry Andric     uint64_t ArgSizeBytes = 0;
182*7a6dacacSDimitry Andric #endif
183*7a6dacacSDimitry Andric     Type *Arm64Ty, *X64Ty;
184*7a6dacacSDimitry Andric     canonicalizeThunkType(FT->getParamType(I), ParamAlign,
185*7a6dacacSDimitry Andric                           /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
186*7a6dacacSDimitry Andric     Arm64ArgTypes.push_back(Arm64Ty);
187*7a6dacacSDimitry Andric     X64ArgTypes.push_back(X64Ty);
188*7a6dacacSDimitry Andric   }
189*7a6dacacSDimitry Andric }
190*7a6dacacSDimitry Andric 
191*7a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkRetType(
192*7a6dacacSDimitry Andric     FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
193*7a6dacacSDimitry Andric     Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
194*7a6dacacSDimitry Andric     SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
195*7a6dacacSDimitry Andric   Type *T = FT->getReturnType();
196*7a6dacacSDimitry Andric #if 0
197*7a6dacacSDimitry Andric   // FIXME: Need more information about argument size; see
198*7a6dacacSDimitry Andric   // https://reviews.llvm.org/D132926
199*7a6dacacSDimitry Andric   uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
200*7a6dacacSDimitry Andric #else
201*7a6dacacSDimitry Andric   int64_t ArgSizeBytes = 0;
202*7a6dacacSDimitry Andric #endif
203*7a6dacacSDimitry Andric   if (T->isVoidTy()) {
204*7a6dacacSDimitry Andric     if (FT->getNumParams()) {
205*7a6dacacSDimitry Andric       auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
206*7a6dacacSDimitry Andric       auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
207*7a6dacacSDimitry Andric       if (SRetAttr.isValid() && InRegAttr.isValid()) {
208*7a6dacacSDimitry Andric         // sret+inreg indicates a call that returns a C++ class value. This is
209*7a6dacacSDimitry Andric         // actually equivalent to just passing and returning a void* pointer
210*7a6dacacSDimitry Andric         // as the first argument. Translate it that way, instead of trying
211*7a6dacacSDimitry Andric         // to model "inreg" in the thunk's calling convention, to simplify
212*7a6dacacSDimitry Andric         // the rest of the code.
213*7a6dacacSDimitry Andric         Out << "i8";
214*7a6dacacSDimitry Andric         Arm64RetTy = I64Ty;
215*7a6dacacSDimitry Andric         X64RetTy = I64Ty;
216*7a6dacacSDimitry Andric         return;
217*7a6dacacSDimitry Andric       }
218*7a6dacacSDimitry Andric       if (SRetAttr.isValid()) {
219*7a6dacacSDimitry Andric         // FIXME: Sanity-check the sret type; if it's an integer or pointer,
220*7a6dacacSDimitry Andric         // we'll get screwy mangling/codegen.
221*7a6dacacSDimitry Andric         // FIXME: For large struct types, mangle as an integer argument and
222*7a6dacacSDimitry Andric         // integer return, so we can reuse more thunks, instead of "m" syntax.
223*7a6dacacSDimitry Andric         // (MSVC mangles this case as an integer return with no argument, but
224*7a6dacacSDimitry Andric         // that's a miscompile.)
225*7a6dacacSDimitry Andric         Type *SRetType = SRetAttr.getValueAsType();
226*7a6dacacSDimitry Andric         Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
227*7a6dacacSDimitry Andric         Type *Arm64Ty, *X64Ty;
228*7a6dacacSDimitry Andric         canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
229*7a6dacacSDimitry Andric                               Out, Arm64Ty, X64Ty);
230*7a6dacacSDimitry Andric         Arm64RetTy = VoidTy;
231*7a6dacacSDimitry Andric         X64RetTy = VoidTy;
232*7a6dacacSDimitry Andric         Arm64ArgTypes.push_back(FT->getParamType(0));
233*7a6dacacSDimitry Andric         X64ArgTypes.push_back(FT->getParamType(0));
234*7a6dacacSDimitry Andric         HasSretPtr = true;
235*7a6dacacSDimitry Andric         return;
236*7a6dacacSDimitry Andric       }
237*7a6dacacSDimitry Andric     }
238*7a6dacacSDimitry Andric 
239*7a6dacacSDimitry Andric     Out << "v";
240*7a6dacacSDimitry Andric     Arm64RetTy = VoidTy;
241*7a6dacacSDimitry Andric     X64RetTy = VoidTy;
242*7a6dacacSDimitry Andric     return;
243*7a6dacacSDimitry Andric   }
244*7a6dacacSDimitry Andric 
245*7a6dacacSDimitry Andric   canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
246*7a6dacacSDimitry Andric                         X64RetTy);
247*7a6dacacSDimitry Andric   if (X64RetTy->isPointerTy()) {
248*7a6dacacSDimitry Andric     // If the X64 type is canonicalized to a pointer, that means it's
249*7a6dacacSDimitry Andric     // passed/returned indirectly. For a return value, that means it's an
250*7a6dacacSDimitry Andric     // sret pointer.
251*7a6dacacSDimitry Andric     X64ArgTypes.push_back(X64RetTy);
252*7a6dacacSDimitry Andric     X64RetTy = VoidTy;
253*7a6dacacSDimitry Andric   }
254*7a6dacacSDimitry Andric }
255*7a6dacacSDimitry Andric 
256*7a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::canonicalizeThunkType(
257*7a6dacacSDimitry Andric     Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
258*7a6dacacSDimitry Andric     Type *&Arm64Ty, Type *&X64Ty) {
259*7a6dacacSDimitry Andric   if (T->isFloatTy()) {
260*7a6dacacSDimitry Andric     Out << "f";
261*7a6dacacSDimitry Andric     Arm64Ty = T;
262*7a6dacacSDimitry Andric     X64Ty = T;
263*7a6dacacSDimitry Andric     return;
264*7a6dacacSDimitry Andric   }
265*7a6dacacSDimitry Andric 
266*7a6dacacSDimitry Andric   if (T->isDoubleTy()) {
267*7a6dacacSDimitry Andric     Out << "d";
268*7a6dacacSDimitry Andric     Arm64Ty = T;
269*7a6dacacSDimitry Andric     X64Ty = T;
270*7a6dacacSDimitry Andric     return;
271*7a6dacacSDimitry Andric   }
272*7a6dacacSDimitry Andric 
273*7a6dacacSDimitry Andric   if (T->isFloatingPointTy()) {
274*7a6dacacSDimitry Andric     report_fatal_error(
275*7a6dacacSDimitry Andric         "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
276*7a6dacacSDimitry Andric   }
277*7a6dacacSDimitry Andric 
278*7a6dacacSDimitry Andric   auto &DL = M->getDataLayout();
279*7a6dacacSDimitry Andric 
280*7a6dacacSDimitry Andric   if (auto *StructTy = dyn_cast<StructType>(T))
281*7a6dacacSDimitry Andric     if (StructTy->getNumElements() == 1)
282*7a6dacacSDimitry Andric       T = StructTy->getElementType(0);
283*7a6dacacSDimitry Andric 
284*7a6dacacSDimitry Andric   if (T->isArrayTy()) {
285*7a6dacacSDimitry Andric     Type *ElementTy = T->getArrayElementType();
286*7a6dacacSDimitry Andric     uint64_t ElementCnt = T->getArrayNumElements();
287*7a6dacacSDimitry Andric     uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
288*7a6dacacSDimitry Andric     uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
289*7a6dacacSDimitry Andric     if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
290*7a6dacacSDimitry Andric       Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
291*7a6dacacSDimitry Andric       if (Alignment.value() >= 8 && !T->isPointerTy())
292*7a6dacacSDimitry Andric         Out << "a" << Alignment.value();
293*7a6dacacSDimitry Andric       Arm64Ty = T;
294*7a6dacacSDimitry Andric       if (TotalSizeBytes <= 8) {
295*7a6dacacSDimitry Andric         // Arm64 returns small structs of float/double in float registers;
296*7a6dacacSDimitry Andric         // X64 uses RAX.
297*7a6dacacSDimitry Andric         X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
298*7a6dacacSDimitry Andric       } else {
299*7a6dacacSDimitry Andric         // Struct is passed directly on Arm64, but indirectly on X64.
300*7a6dacacSDimitry Andric         X64Ty = PtrTy;
301*7a6dacacSDimitry Andric       }
302*7a6dacacSDimitry Andric       return;
303*7a6dacacSDimitry Andric     } else if (T->isFloatingPointTy()) {
304*7a6dacacSDimitry Andric       report_fatal_error("Only 32 and 64 bit floating points are supported for "
305*7a6dacacSDimitry Andric                          "ARM64EC thunks");
306*7a6dacacSDimitry Andric     }
307*7a6dacacSDimitry Andric   }
308*7a6dacacSDimitry Andric 
309*7a6dacacSDimitry Andric   if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
310*7a6dacacSDimitry Andric     Out << "i8";
311*7a6dacacSDimitry Andric     Arm64Ty = I64Ty;
312*7a6dacacSDimitry Andric     X64Ty = I64Ty;
313*7a6dacacSDimitry Andric     return;
314*7a6dacacSDimitry Andric   }
315*7a6dacacSDimitry Andric 
316*7a6dacacSDimitry Andric   unsigned TypeSize = ArgSizeBytes;
317*7a6dacacSDimitry Andric   if (TypeSize == 0)
318*7a6dacacSDimitry Andric     TypeSize = DL.getTypeSizeInBits(T) / 8;
319*7a6dacacSDimitry Andric   Out << "m";
320*7a6dacacSDimitry Andric   if (TypeSize != 4)
321*7a6dacacSDimitry Andric     Out << TypeSize;
322*7a6dacacSDimitry Andric   if (Alignment.value() >= 8 && !T->isPointerTy())
323*7a6dacacSDimitry Andric     Out << "a" << Alignment.value();
324*7a6dacacSDimitry Andric   // FIXME: Try to canonicalize Arm64Ty more thoroughly?
325*7a6dacacSDimitry Andric   Arm64Ty = T;
326*7a6dacacSDimitry Andric   if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
327*7a6dacacSDimitry Andric     // Pass directly in an integer register
328*7a6dacacSDimitry Andric     X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
329*7a6dacacSDimitry Andric   } else {
330*7a6dacacSDimitry Andric     // Passed directly on Arm64, but indirectly on X64.
331*7a6dacacSDimitry Andric     X64Ty = PtrTy;
332*7a6dacacSDimitry Andric   }
333*7a6dacacSDimitry Andric }
334*7a6dacacSDimitry Andric 
335*7a6dacacSDimitry Andric // This function builds the "exit thunk", a function which translates
336*7a6dacacSDimitry Andric // arguments and return values when calling x64 code from AArch64 code.
337*7a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
338*7a6dacacSDimitry Andric                                                      AttributeList Attrs) {
339*7a6dacacSDimitry Andric   SmallString<256> ExitThunkName;
340*7a6dacacSDimitry Andric   llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
341*7a6dacacSDimitry Andric   FunctionType *Arm64Ty, *X64Ty;
342*7a6dacacSDimitry Andric   getThunkType(FT, Attrs, /*EntryThunk*/ false, ExitThunkStream, Arm64Ty,
343*7a6dacacSDimitry Andric                X64Ty);
344*7a6dacacSDimitry Andric   if (Function *F = M->getFunction(ExitThunkName))
345*7a6dacacSDimitry Andric     return F;
346*7a6dacacSDimitry Andric 
347*7a6dacacSDimitry Andric   Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
348*7a6dacacSDimitry Andric                                  ExitThunkName, M);
349*7a6dacacSDimitry Andric   F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
350*7a6dacacSDimitry Andric   F->setSection(".wowthk$aa");
351*7a6dacacSDimitry Andric   F->setComdat(M->getOrInsertComdat(ExitThunkName));
352*7a6dacacSDimitry Andric   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
353*7a6dacacSDimitry Andric   F->addFnAttr("frame-pointer", "all");
354*7a6dacacSDimitry Andric   // Only copy sret from the first argument. For C++ instance methods, clang can
355*7a6dacacSDimitry Andric   // stick an sret marking on a later argument, but it doesn't actually affect
356*7a6dacacSDimitry Andric   // the ABI, so we can omit it. This avoids triggering a verifier assertion.
357*7a6dacacSDimitry Andric   if (FT->getNumParams()) {
358*7a6dacacSDimitry Andric     auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
359*7a6dacacSDimitry Andric     auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
360*7a6dacacSDimitry Andric     if (SRet.isValid() && !InReg.isValid())
361*7a6dacacSDimitry Andric       F->addParamAttr(1, SRet);
362*7a6dacacSDimitry Andric   }
363*7a6dacacSDimitry Andric   // FIXME: Copy anything other than sret?  Shouldn't be necessary for normal
364*7a6dacacSDimitry Andric   // C ABI, but might show up in other cases.
365*7a6dacacSDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
366*7a6dacacSDimitry Andric   IRBuilder<> IRB(BB);
367*7a6dacacSDimitry Andric   Value *CalleePtr =
368*7a6dacacSDimitry Andric       M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
369*7a6dacacSDimitry Andric   Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
370*7a6dacacSDimitry Andric   auto &DL = M->getDataLayout();
371*7a6dacacSDimitry Andric   SmallVector<Value *> Args;
372*7a6dacacSDimitry Andric 
373*7a6dacacSDimitry Andric   // Pass the called function in x9.
374*7a6dacacSDimitry Andric   Args.push_back(F->arg_begin());
375*7a6dacacSDimitry Andric 
376*7a6dacacSDimitry Andric   Type *RetTy = Arm64Ty->getReturnType();
377*7a6dacacSDimitry Andric   if (RetTy != X64Ty->getReturnType()) {
378*7a6dacacSDimitry Andric     // If the return type is an array or struct, translate it. Values of size
379*7a6dacacSDimitry Andric     // 8 or less go into RAX; bigger values go into memory, and we pass a
380*7a6dacacSDimitry Andric     // pointer.
381*7a6dacacSDimitry Andric     if (DL.getTypeStoreSize(RetTy) > 8) {
382*7a6dacacSDimitry Andric       Args.push_back(IRB.CreateAlloca(RetTy));
383*7a6dacacSDimitry Andric     }
384*7a6dacacSDimitry Andric   }
385*7a6dacacSDimitry Andric 
386*7a6dacacSDimitry Andric   for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
387*7a6dacacSDimitry Andric     // Translate arguments from AArch64 calling convention to x86 calling
388*7a6dacacSDimitry Andric     // convention.
389*7a6dacacSDimitry Andric     //
390*7a6dacacSDimitry Andric     // For simple types, we don't need to do any translation: they're
391*7a6dacacSDimitry Andric     // represented the same way. (Implicit sign extension is not part of
392*7a6dacacSDimitry Andric     // either convention.)
393*7a6dacacSDimitry Andric     //
394*7a6dacacSDimitry Andric     // The big thing we have to worry about is struct types... but
395*7a6dacacSDimitry Andric     // fortunately AArch64 clang is pretty friendly here: the cases that need
396*7a6dacacSDimitry Andric     // translation are always passed as a struct or array. (If we run into
397*7a6dacacSDimitry Andric     // some cases where this doesn't work, we can teach clang to mark it up
398*7a6dacacSDimitry Andric     // with an attribute.)
399*7a6dacacSDimitry Andric     //
400*7a6dacacSDimitry Andric     // The first argument is the called function, stored in x9.
401*7a6dacacSDimitry Andric     if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
402*7a6dacacSDimitry Andric         DL.getTypeStoreSize(Arg.getType()) > 8) {
403*7a6dacacSDimitry Andric       Value *Mem = IRB.CreateAlloca(Arg.getType());
404*7a6dacacSDimitry Andric       IRB.CreateStore(&Arg, Mem);
405*7a6dacacSDimitry Andric       if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
406*7a6dacacSDimitry Andric         Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
407*7a6dacacSDimitry Andric         Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
408*7a6dacacSDimitry Andric       } else
409*7a6dacacSDimitry Andric         Args.push_back(Mem);
410*7a6dacacSDimitry Andric     } else {
411*7a6dacacSDimitry Andric       Args.push_back(&Arg);
412*7a6dacacSDimitry Andric     }
413*7a6dacacSDimitry Andric   }
414*7a6dacacSDimitry Andric   // FIXME: Transfer necessary attributes? sret? anything else?
415*7a6dacacSDimitry Andric 
416*7a6dacacSDimitry Andric   Callee = IRB.CreateBitCast(Callee, PtrTy);
417*7a6dacacSDimitry Andric   CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
418*7a6dacacSDimitry Andric   Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
419*7a6dacacSDimitry Andric 
420*7a6dacacSDimitry Andric   Value *RetVal = Call;
421*7a6dacacSDimitry Andric   if (RetTy != X64Ty->getReturnType()) {
422*7a6dacacSDimitry Andric     // If we rewrote the return type earlier, convert the return value to
423*7a6dacacSDimitry Andric     // the proper type.
424*7a6dacacSDimitry Andric     if (DL.getTypeStoreSize(RetTy) > 8) {
425*7a6dacacSDimitry Andric       RetVal = IRB.CreateLoad(RetTy, Args[1]);
426*7a6dacacSDimitry Andric     } else {
427*7a6dacacSDimitry Andric       Value *CastAlloca = IRB.CreateAlloca(RetTy);
428*7a6dacacSDimitry Andric       IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
429*7a6dacacSDimitry Andric       RetVal = IRB.CreateLoad(RetTy, CastAlloca);
430*7a6dacacSDimitry Andric     }
431*7a6dacacSDimitry Andric   }
432*7a6dacacSDimitry Andric 
433*7a6dacacSDimitry Andric   if (RetTy->isVoidTy())
434*7a6dacacSDimitry Andric     IRB.CreateRetVoid();
435*7a6dacacSDimitry Andric   else
436*7a6dacacSDimitry Andric     IRB.CreateRet(RetVal);
437*7a6dacacSDimitry Andric   return F;
438*7a6dacacSDimitry Andric }
439*7a6dacacSDimitry Andric 
440*7a6dacacSDimitry Andric // This function builds the "entry thunk", a function which translates
441*7a6dacacSDimitry Andric // arguments and return values when calling AArch64 code from x64 code.
442*7a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
443*7a6dacacSDimitry Andric   SmallString<256> EntryThunkName;
444*7a6dacacSDimitry Andric   llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
445*7a6dacacSDimitry Andric   FunctionType *Arm64Ty, *X64Ty;
446*7a6dacacSDimitry Andric   getThunkType(F->getFunctionType(), F->getAttributes(), /*EntryThunk*/ true,
447*7a6dacacSDimitry Andric                EntryThunkStream, Arm64Ty, X64Ty);
448*7a6dacacSDimitry Andric   if (Function *F = M->getFunction(EntryThunkName))
449*7a6dacacSDimitry Andric     return F;
450*7a6dacacSDimitry Andric 
451*7a6dacacSDimitry Andric   Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
452*7a6dacacSDimitry Andric                                      EntryThunkName, M);
453*7a6dacacSDimitry Andric   Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
454*7a6dacacSDimitry Andric   Thunk->setSection(".wowthk$aa");
455*7a6dacacSDimitry Andric   Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
456*7a6dacacSDimitry Andric   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
457*7a6dacacSDimitry Andric   Thunk->addFnAttr("frame-pointer", "all");
458*7a6dacacSDimitry Andric 
459*7a6dacacSDimitry Andric   auto &DL = M->getDataLayout();
460*7a6dacacSDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
461*7a6dacacSDimitry Andric   IRBuilder<> IRB(BB);
462*7a6dacacSDimitry Andric 
463*7a6dacacSDimitry Andric   Type *RetTy = Arm64Ty->getReturnType();
464*7a6dacacSDimitry Andric   Type *X64RetType = X64Ty->getReturnType();
465*7a6dacacSDimitry Andric 
466*7a6dacacSDimitry Andric   bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
467*7a6dacacSDimitry Andric   unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
468*7a6dacacSDimitry Andric 
469*7a6dacacSDimitry Andric   // Translate arguments to call.
470*7a6dacacSDimitry Andric   SmallVector<Value *> Args;
471*7a6dacacSDimitry Andric   for (unsigned i = ThunkArgOffset, e = Thunk->arg_size(); i != e; ++i) {
472*7a6dacacSDimitry Andric     Value *Arg = Thunk->getArg(i);
473*7a6dacacSDimitry Andric     Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
474*7a6dacacSDimitry Andric     if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
475*7a6dacacSDimitry Andric         DL.getTypeStoreSize(ArgTy) > 8) {
476*7a6dacacSDimitry Andric       // Translate array/struct arguments to the expected type.
477*7a6dacacSDimitry Andric       if (DL.getTypeStoreSize(ArgTy) <= 8) {
478*7a6dacacSDimitry Andric         Value *CastAlloca = IRB.CreateAlloca(ArgTy);
479*7a6dacacSDimitry Andric         IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
480*7a6dacacSDimitry Andric         Arg = IRB.CreateLoad(ArgTy, CastAlloca);
481*7a6dacacSDimitry Andric       } else {
482*7a6dacacSDimitry Andric         Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
483*7a6dacacSDimitry Andric       }
484*7a6dacacSDimitry Andric     }
485*7a6dacacSDimitry Andric     Args.push_back(Arg);
486*7a6dacacSDimitry Andric   }
487*7a6dacacSDimitry Andric 
488*7a6dacacSDimitry Andric   // Call the function passed to the thunk.
489*7a6dacacSDimitry Andric   Value *Callee = Thunk->getArg(0);
490*7a6dacacSDimitry Andric   Callee = IRB.CreateBitCast(Callee, PtrTy);
491*7a6dacacSDimitry Andric   Value *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
492*7a6dacacSDimitry Andric 
493*7a6dacacSDimitry Andric   Value *RetVal = Call;
494*7a6dacacSDimitry Andric   if (TransformDirectToSRet) {
495*7a6dacacSDimitry Andric     IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
496*7a6dacacSDimitry Andric   } else if (X64RetType != RetTy) {
497*7a6dacacSDimitry Andric     Value *CastAlloca = IRB.CreateAlloca(X64RetType);
498*7a6dacacSDimitry Andric     IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
499*7a6dacacSDimitry Andric     RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
500*7a6dacacSDimitry Andric   }
501*7a6dacacSDimitry Andric 
502*7a6dacacSDimitry Andric   // Return to the caller.  Note that the isel has code to translate this
503*7a6dacacSDimitry Andric   // "ret" to a tail call to __os_arm64x_dispatch_ret.  (Alternatively, we
504*7a6dacacSDimitry Andric   // could emit a tail call here, but that would require a dedicated calling
505*7a6dacacSDimitry Andric   // convention, which seems more complicated overall.)
506*7a6dacacSDimitry Andric   if (X64RetType->isVoidTy())
507*7a6dacacSDimitry Andric     IRB.CreateRetVoid();
508*7a6dacacSDimitry Andric   else
509*7a6dacacSDimitry Andric     IRB.CreateRet(RetVal);
510*7a6dacacSDimitry Andric 
511*7a6dacacSDimitry Andric   return Thunk;
512*7a6dacacSDimitry Andric }
513*7a6dacacSDimitry Andric 
514*7a6dacacSDimitry Andric // Builds the "guest exit thunk", a helper to call a function which may or may
515*7a6dacacSDimitry Andric // not be an exit thunk. (We optimistically assume non-dllimport function
516*7a6dacacSDimitry Andric // declarations refer to functions defined in AArch64 code; if the linker
517*7a6dacacSDimitry Andric // can't prove that, we use this routine instead.)
518*7a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
519*7a6dacacSDimitry Andric   llvm::raw_null_ostream NullThunkName;
520*7a6dacacSDimitry Andric   FunctionType *Arm64Ty, *X64Ty;
521*7a6dacacSDimitry Andric   getThunkType(F->getFunctionType(), F->getAttributes(), /*EntryThunk*/ true,
522*7a6dacacSDimitry Andric                NullThunkName, Arm64Ty, X64Ty);
523*7a6dacacSDimitry Andric   auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
524*7a6dacacSDimitry Andric   assert(MangledName && "Can't guest exit to function that's already native");
525*7a6dacacSDimitry Andric   std::string ThunkName = *MangledName;
526*7a6dacacSDimitry Andric   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
527*7a6dacacSDimitry Andric     ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
528*7a6dacacSDimitry Andric   } else {
529*7a6dacacSDimitry Andric     ThunkName.append("$exit_thunk");
530*7a6dacacSDimitry Andric   }
531*7a6dacacSDimitry Andric   Function *GuestExit =
532*7a6dacacSDimitry Andric       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
533*7a6dacacSDimitry Andric   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
534*7a6dacacSDimitry Andric   GuestExit->setSection(".wowthk$aa");
535*7a6dacacSDimitry Andric   GuestExit->setMetadata(
536*7a6dacacSDimitry Andric       "arm64ec_unmangled_name",
537*7a6dacacSDimitry Andric       MDNode::get(M->getContext(),
538*7a6dacacSDimitry Andric                   MDString::get(M->getContext(), F->getName())));
539*7a6dacacSDimitry Andric   GuestExit->setMetadata(
540*7a6dacacSDimitry Andric       "arm64ec_ecmangled_name",
541*7a6dacacSDimitry Andric       MDNode::get(M->getContext(),
542*7a6dacacSDimitry Andric                   MDString::get(M->getContext(), *MangledName)));
543*7a6dacacSDimitry Andric   F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
544*7a6dacacSDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
545*7a6dacacSDimitry Andric   IRBuilder<> B(BB);
546*7a6dacacSDimitry Andric 
547*7a6dacacSDimitry Andric   // Load the global symbol as a pointer to the check function.
548*7a6dacacSDimitry Andric   Value *GuardFn;
549*7a6dacacSDimitry Andric   if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
550*7a6dacacSDimitry Andric     GuardFn = GuardFnCFGlobal;
551*7a6dacacSDimitry Andric   else
552*7a6dacacSDimitry Andric     GuardFn = GuardFnGlobal;
553*7a6dacacSDimitry Andric   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
554*7a6dacacSDimitry Andric 
555*7a6dacacSDimitry Andric   // Create new call instruction. The CFGuard check should always be a call,
556*7a6dacacSDimitry Andric   // even if the original CallBase is an Invoke or CallBr instruction.
557*7a6dacacSDimitry Andric   Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
558*7a6dacacSDimitry Andric   CallInst *GuardCheck = B.CreateCall(
559*7a6dacacSDimitry Andric       GuardFnType, GuardCheckLoad,
560*7a6dacacSDimitry Andric       {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
561*7a6dacacSDimitry Andric 
562*7a6dacacSDimitry Andric   // Ensure that the first argument is passed in the correct register.
563*7a6dacacSDimitry Andric   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
564*7a6dacacSDimitry Andric 
565*7a6dacacSDimitry Andric   Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
566*7a6dacacSDimitry Andric   SmallVector<Value *> Args;
567*7a6dacacSDimitry Andric   for (Argument &Arg : GuestExit->args())
568*7a6dacacSDimitry Andric     Args.push_back(&Arg);
569*7a6dacacSDimitry Andric   CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
570*7a6dacacSDimitry Andric   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
571*7a6dacacSDimitry Andric 
572*7a6dacacSDimitry Andric   if (Call->getType()->isVoidTy())
573*7a6dacacSDimitry Andric     B.CreateRetVoid();
574*7a6dacacSDimitry Andric   else
575*7a6dacacSDimitry Andric     B.CreateRet(Call);
576*7a6dacacSDimitry Andric 
577*7a6dacacSDimitry Andric   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
578*7a6dacacSDimitry Andric   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
579*7a6dacacSDimitry Andric   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
580*7a6dacacSDimitry Andric     GuestExit->addParamAttr(0, SRetAttr);
581*7a6dacacSDimitry Andric     Call->addParamAttr(0, SRetAttr);
582*7a6dacacSDimitry Andric   }
583*7a6dacacSDimitry Andric 
584*7a6dacacSDimitry Andric   return GuestExit;
585*7a6dacacSDimitry Andric }
586*7a6dacacSDimitry Andric 
587*7a6dacacSDimitry Andric // Lower an indirect call with inline code.
588*7a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
589*7a6dacacSDimitry Andric   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
590*7a6dacacSDimitry Andric          "Only applicable for Windows targets");
591*7a6dacacSDimitry Andric 
592*7a6dacacSDimitry Andric   IRBuilder<> B(CB);
593*7a6dacacSDimitry Andric   Value *CalledOperand = CB->getCalledOperand();
594*7a6dacacSDimitry Andric 
595*7a6dacacSDimitry Andric   // If the indirect call is called within catchpad or cleanuppad,
596*7a6dacacSDimitry Andric   // we need to copy "funclet" bundle of the call.
597*7a6dacacSDimitry Andric   SmallVector<llvm::OperandBundleDef, 1> Bundles;
598*7a6dacacSDimitry Andric   if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
599*7a6dacacSDimitry Andric     Bundles.push_back(OperandBundleDef(*Bundle));
600*7a6dacacSDimitry Andric 
601*7a6dacacSDimitry Andric   // Load the global symbol as a pointer to the check function.
602*7a6dacacSDimitry Andric   Value *GuardFn;
603*7a6dacacSDimitry Andric   if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
604*7a6dacacSDimitry Andric     GuardFn = GuardFnCFGlobal;
605*7a6dacacSDimitry Andric   else
606*7a6dacacSDimitry Andric     GuardFn = GuardFnGlobal;
607*7a6dacacSDimitry Andric   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
608*7a6dacacSDimitry Andric 
609*7a6dacacSDimitry Andric   // Create new call instruction. The CFGuard check should always be a call,
610*7a6dacacSDimitry Andric   // even if the original CallBase is an Invoke or CallBr instruction.
611*7a6dacacSDimitry Andric   Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
612*7a6dacacSDimitry Andric   CallInst *GuardCheck =
613*7a6dacacSDimitry Andric       B.CreateCall(GuardFnType, GuardCheckLoad,
614*7a6dacacSDimitry Andric                    {B.CreateBitCast(CalledOperand, B.getPtrTy()),
615*7a6dacacSDimitry Andric                     B.CreateBitCast(Thunk, B.getPtrTy())},
616*7a6dacacSDimitry Andric                    Bundles);
617*7a6dacacSDimitry Andric 
618*7a6dacacSDimitry Andric   // Ensure that the first argument is passed in the correct register.
619*7a6dacacSDimitry Andric   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
620*7a6dacacSDimitry Andric 
621*7a6dacacSDimitry Andric   Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
622*7a6dacacSDimitry Andric   CB->setCalledOperand(GuardRetVal);
623*7a6dacacSDimitry Andric }
624*7a6dacacSDimitry Andric 
625*7a6dacacSDimitry Andric bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
626*7a6dacacSDimitry Andric   if (!GenerateThunks)
627*7a6dacacSDimitry Andric     return false;
628*7a6dacacSDimitry Andric 
629*7a6dacacSDimitry Andric   M = &Mod;
630*7a6dacacSDimitry Andric 
631*7a6dacacSDimitry Andric   // Check if this module has the cfguard flag and read its value.
632*7a6dacacSDimitry Andric   if (auto *MD =
633*7a6dacacSDimitry Andric           mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
634*7a6dacacSDimitry Andric     cfguard_module_flag = MD->getZExtValue();
635*7a6dacacSDimitry Andric 
636*7a6dacacSDimitry Andric   PtrTy = PointerType::getUnqual(M->getContext());
637*7a6dacacSDimitry Andric   I64Ty = Type::getInt64Ty(M->getContext());
638*7a6dacacSDimitry Andric   VoidTy = Type::getVoidTy(M->getContext());
639*7a6dacacSDimitry Andric 
640*7a6dacacSDimitry Andric   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
641*7a6dacacSDimitry Andric   GuardFnPtrType = PointerType::get(GuardFnType, 0);
642*7a6dacacSDimitry Andric   GuardFnCFGlobal =
643*7a6dacacSDimitry Andric       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
644*7a6dacacSDimitry Andric   GuardFnGlobal =
645*7a6dacacSDimitry Andric       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
646*7a6dacacSDimitry Andric 
647*7a6dacacSDimitry Andric   SetVector<Function *> DirectCalledFns;
648*7a6dacacSDimitry Andric   for (Function &F : Mod)
649*7a6dacacSDimitry Andric     if (!F.isDeclaration() &&
650*7a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
651*7a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
652*7a6dacacSDimitry Andric       processFunction(F, DirectCalledFns);
653*7a6dacacSDimitry Andric 
654*7a6dacacSDimitry Andric   struct ThunkInfo {
655*7a6dacacSDimitry Andric     Constant *Src;
656*7a6dacacSDimitry Andric     Constant *Dst;
657*7a6dacacSDimitry Andric     unsigned Kind;
658*7a6dacacSDimitry Andric   };
659*7a6dacacSDimitry Andric   SmallVector<ThunkInfo> ThunkMapping;
660*7a6dacacSDimitry Andric   for (Function &F : Mod) {
661*7a6dacacSDimitry Andric     if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
662*7a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
663*7a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
664*7a6dacacSDimitry Andric       if (!F.hasComdat())
665*7a6dacacSDimitry Andric         F.setComdat(Mod.getOrInsertComdat(F.getName()));
666*7a6dacacSDimitry Andric       ThunkMapping.push_back({&F, buildEntryThunk(&F), 1});
667*7a6dacacSDimitry Andric     }
668*7a6dacacSDimitry Andric   }
669*7a6dacacSDimitry Andric   for (Function *F : DirectCalledFns) {
670*7a6dacacSDimitry Andric     ThunkMapping.push_back(
671*7a6dacacSDimitry Andric         {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4});
672*7a6dacacSDimitry Andric     if (!F->hasDLLImportStorageClass())
673*7a6dacacSDimitry Andric       ThunkMapping.push_back({buildGuestExitThunk(F), F, 0});
674*7a6dacacSDimitry Andric   }
675*7a6dacacSDimitry Andric 
676*7a6dacacSDimitry Andric   if (!ThunkMapping.empty()) {
677*7a6dacacSDimitry Andric     SmallVector<Constant *> ThunkMappingArrayElems;
678*7a6dacacSDimitry Andric     for (ThunkInfo &Thunk : ThunkMapping) {
679*7a6dacacSDimitry Andric       ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
680*7a6dacacSDimitry Andric           {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
681*7a6dacacSDimitry Andric            ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
682*7a6dacacSDimitry Andric            ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))}));
683*7a6dacacSDimitry Andric     }
684*7a6dacacSDimitry Andric     Constant *ThunkMappingArray = ConstantArray::get(
685*7a6dacacSDimitry Andric         llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
686*7a6dacacSDimitry Andric                              ThunkMappingArrayElems.size()),
687*7a6dacacSDimitry Andric         ThunkMappingArrayElems);
688*7a6dacacSDimitry Andric     new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
689*7a6dacacSDimitry Andric                        GlobalValue::ExternalLinkage, ThunkMappingArray,
690*7a6dacacSDimitry Andric                        "llvm.arm64ec.symbolmap");
691*7a6dacacSDimitry Andric   }
692*7a6dacacSDimitry Andric 
693*7a6dacacSDimitry Andric   return true;
694*7a6dacacSDimitry Andric }
695*7a6dacacSDimitry Andric 
696*7a6dacacSDimitry Andric bool AArch64Arm64ECCallLowering::processFunction(
697*7a6dacacSDimitry Andric     Function &F, SetVector<Function *> &DirectCalledFns) {
698*7a6dacacSDimitry Andric   SmallVector<CallBase *, 8> IndirectCalls;
699*7a6dacacSDimitry Andric 
700*7a6dacacSDimitry Andric   // For ARM64EC targets, a function definition's name is mangled differently
701*7a6dacacSDimitry Andric   // from the normal symbol. We currently have no representation of this sort
702*7a6dacacSDimitry Andric   // of symbol in IR, so we change the name to the mangled name, then store
703*7a6dacacSDimitry Andric   // the unmangled name as metadata.  Later passes that need the unmangled
704*7a6dacacSDimitry Andric   // name (emitting the definition) can grab it from the metadata.
705*7a6dacacSDimitry Andric   //
706*7a6dacacSDimitry Andric   // FIXME: Handle functions with weak linkage?
707*7a6dacacSDimitry Andric   if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) {
708*7a6dacacSDimitry Andric     if (std::optional<std::string> MangledName =
709*7a6dacacSDimitry Andric             getArm64ECMangledFunctionName(F.getName().str())) {
710*7a6dacacSDimitry Andric       F.setMetadata("arm64ec_unmangled_name",
711*7a6dacacSDimitry Andric                     MDNode::get(M->getContext(),
712*7a6dacacSDimitry Andric                                 MDString::get(M->getContext(), F.getName())));
713*7a6dacacSDimitry Andric       if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
714*7a6dacacSDimitry Andric         Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
715*7a6dacacSDimitry Andric         SmallVector<GlobalObject *> ComdatUsers =
716*7a6dacacSDimitry Andric             to_vector(F.getComdat()->getUsers());
717*7a6dacacSDimitry Andric         for (GlobalObject *User : ComdatUsers)
718*7a6dacacSDimitry Andric           User->setComdat(MangledComdat);
719*7a6dacacSDimitry Andric       }
720*7a6dacacSDimitry Andric       F.setName(MangledName.value());
721*7a6dacacSDimitry Andric     }
722*7a6dacacSDimitry Andric   }
723*7a6dacacSDimitry Andric 
724*7a6dacacSDimitry Andric   // Iterate over the instructions to find all indirect call/invoke/callbr
725*7a6dacacSDimitry Andric   // instructions. Make a separate list of pointers to indirect
726*7a6dacacSDimitry Andric   // call/invoke/callbr instructions because the original instructions will be
727*7a6dacacSDimitry Andric   // deleted as the checks are added.
728*7a6dacacSDimitry Andric   for (BasicBlock &BB : F) {
729*7a6dacacSDimitry Andric     for (Instruction &I : BB) {
730*7a6dacacSDimitry Andric       auto *CB = dyn_cast<CallBase>(&I);
731*7a6dacacSDimitry Andric       if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
732*7a6dacacSDimitry Andric           CB->isInlineAsm())
733*7a6dacacSDimitry Andric         continue;
734*7a6dacacSDimitry Andric 
735*7a6dacacSDimitry Andric       // We need to instrument any call that isn't directly calling an
736*7a6dacacSDimitry Andric       // ARM64 function.
737*7a6dacacSDimitry Andric       //
738*7a6dacacSDimitry Andric       // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
739*7a6dacacSDimitry Andric       // unprototyped functions in C)
740*7a6dacacSDimitry Andric       if (Function *F = CB->getCalledFunction()) {
741*7a6dacacSDimitry Andric         if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
742*7a6dacacSDimitry Andric             F->isIntrinsic() || !F->isDeclaration())
743*7a6dacacSDimitry Andric           continue;
744*7a6dacacSDimitry Andric 
745*7a6dacacSDimitry Andric         DirectCalledFns.insert(F);
746*7a6dacacSDimitry Andric         continue;
747*7a6dacacSDimitry Andric       }
748*7a6dacacSDimitry Andric 
749*7a6dacacSDimitry Andric       IndirectCalls.push_back(CB);
750*7a6dacacSDimitry Andric       ++Arm64ECCallsLowered;
751*7a6dacacSDimitry Andric     }
752*7a6dacacSDimitry Andric   }
753*7a6dacacSDimitry Andric 
754*7a6dacacSDimitry Andric   if (IndirectCalls.empty())
755*7a6dacacSDimitry Andric     return false;
756*7a6dacacSDimitry Andric 
757*7a6dacacSDimitry Andric   for (CallBase *CB : IndirectCalls)
758*7a6dacacSDimitry Andric     lowerCall(CB);
759*7a6dacacSDimitry Andric 
760*7a6dacacSDimitry Andric   return true;
761*7a6dacacSDimitry Andric }
762*7a6dacacSDimitry Andric 
763*7a6dacacSDimitry Andric char AArch64Arm64ECCallLowering::ID = 0;
764*7a6dacacSDimitry Andric INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
765*7a6dacacSDimitry Andric                 "AArch64Arm64ECCallLowering", false, false)
766*7a6dacacSDimitry Andric 
767*7a6dacacSDimitry Andric ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
768*7a6dacacSDimitry Andric   return new AArch64Arm64ECCallLowering;
769*7a6dacacSDimitry Andric }
770