//===-- lib/CodeGen/GlobalISel/CallLowering.cpp - Call lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements some simple delegations needed for call lowering.
///
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"

#define DEBUG_TYPE "call-lowering"

using namespace llvm;
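// Out-of-line virtual method, used to anchor CallLowering's vtable to this
// translation unit.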
void CallLowering::anchor() {}

bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, ImmutableCallSite CS,
                             ArrayRef<Register> ResRegs,
                             ArrayRef<ArrayRef<Register>> ArgRegs,
                             Register SwiftErrorVReg,
                             std::function<unsigned()> GetCalleeReg) const {
  CallLoweringInfo Info;
  auto &DL = CS.getParent()->getParent()->getParent()->getDataLayout();

  // First step is to marshall all the function's parameters into the correct
  // physregs and memory locations. Gather the sequence of argument types that
  // we'll pass to the assigner function.
  unsigned i = 0;
  unsigned NumFixedArgs = CS.getFunctionType()->getNumParams();
  for (auto &Arg : CS.args()) {
    ArgInfo OrigArg{ArgRegs[i], Arg->getType(), ISD::ArgFlagsTy{},
                    i < NumFixedArgs};
    setArgFlags(OrigArg, i + AttributeList::FirstArgIndex, DL, CS);
    Info.OrigArgs.push_back(OrigArg);
    ++i;
  }

  if (const Function *F = CS.getCalledFunction())
    Info.Callee = MachineOperand::CreateGA(F, 0);
  else {
    // Try looking through a bitcast from one function type to another.
    // Commonly happens with calls to objc_msgSend().
    const Value *CalleeV = CS.getCalledValue();
    auto *BC = dyn_cast<ConstantExpr>(CalleeV);
    if (BC && BC->getOpcode() == Instruction::BitCast) {
      if (const auto *F = dyn_cast<Function>(BC->getOperand(0))) {
        Info.Callee = MachineOperand::CreateGA(F, 0);
      } else {
        // A bitcast of something that isn't a function: fall back to an
        // indirect call so Info.Callee is never left uninitialized.
        Info.Callee = MachineOperand::CreateReg(GetCalleeReg(), false);
      }
    } else {
      Info.Callee = MachineOperand::CreateReg(GetCalleeReg(), false);
    }
  }

  Info.OrigRet = ArgInfo{ResRegs, CS.getType(), ISD::ArgFlagsTy{}};
  if (!Info.OrigRet.Ty->isVoidTy())
    setArgFlags(Info.OrigRet, AttributeList::ReturnIndex, DL, CS);

  Info.KnownCallees =
      CS.getInstruction()->getMetadata(LLVMContext::MD_callees);
  Info.CallConv = CS.getCallingConv();
  Info.SwiftErrorVReg = SwiftErrorVReg;
  Info.IsMustTailCall = CS.isMustTailCall();
  Info.IsTailCall = CS.isTailCall() &&
                    isInTailCallPosition(CS, MIRBuilder.getMF().getTarget()) &&
                    (MIRBuilder.getMF()
                         .getFunction()
                         .getFnAttribute("disable-tail-calls")
                         .getValueAsString() != "true");
  Info.IsVarArg = CS.getFunctionType()->isVarArg();
  return lowerCall(MIRBuilder, Info);
}

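// Translate the IR attributes on operand OpIdx of FuncInfo (a Function or a
// CallInst) into the corresponding ISD::ArgFlagsTy bits on Arg.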
template <typename FuncInfoTy>
void CallLowering::setArgFlags(CallLowering::ArgInfo &Arg, unsigned OpIdx,
                               const DataLayout &DL,
                               const FuncInfoTy &FuncInfo) const {
  auto &Flags = Arg.Flags[0];
  const AttributeList &Attrs = FuncInfo.getAttributes();
  if (Attrs.hasAttribute(OpIdx, Attribute::ZExt))
    Flags.setZExt();
  if (Attrs.hasAttribute(OpIdx, Attribute::SExt))
    Flags.setSExt();
  if (Attrs.hasAttribute(OpIdx, Attribute::InReg))
    Flags.setInReg();
  if (Attrs.hasAttribute(OpIdx, Attribute::StructRet))
    Flags.setSRet();
  if (Attrs.hasAttribute(OpIdx, Attribute::SwiftSelf))
    Flags.setSwiftSelf();
  if (Attrs.hasAttribute(OpIdx, Attribute::SwiftError))
    Flags.setSwiftError();
  if (Attrs.hasAttribute(OpIdx, Attribute::ByVal))
    Flags.setByVal();
  if (Attrs.hasAttribute(OpIdx, Attribute::InAlloca))
    Flags.setInAlloca();

  if (Flags.isByVal() || Flags.isInAlloca()) {
    Type *ElementTy = cast<PointerType>(Arg.Ty)->getElementType();

    auto Ty = Attrs.getAttribute(OpIdx, Attribute::ByVal).getValueAsType();
    Flags.setByValSize(DL.getTypeAllocSize(Ty ? Ty : ElementTy));

    // For ByVal, alignment should be passed from FE.  BE will guess if
    // this info is not there but there are cases it cannot get right.
    // Note that getParamAlignment takes a zero-based argument number, while
    // OpIdx is an attribute index, so convert between the two.
    unsigned ArgNo = OpIdx - AttributeList::FirstArgIndex;
    unsigned FrameAlign;
    if (FuncInfo.getParamAlignment(ArgNo))
      FrameAlign = FuncInfo.getParamAlignment(ArgNo);
    else
      FrameAlign = getTLI()->getByValTypeAlignment(ElementTy, DL);
    Flags.setByValAlign(Align(FrameAlign));
  }
  if (Attrs.hasAttribute(OpIdx, Attribute::Nest))
    Flags.setNest();
  Flags.setOrigAlign(Align(DL.getABITypeAlignment(Arg.Ty)));
}

template void
CallLowering::setArgFlags<Function>(CallLowering::ArgInfo &Arg, unsigned OpIdx,
                                    const DataLayout &DL,
                                    const Function &FuncInfo) const;

template void
CallLowering::setArgFlags<CallInst>(CallLowering::ArgInfo &Arg, unsigned OpIdx,
                                    const DataLayout &DL,
                                    const CallInst &FuncInfo) const;

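// Pack multiple source registers into a single value of type PackedTy by
// inserting each one at its computed offset within an initially undef vreg.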
Register CallLowering::packRegs(ArrayRef<Register> SrcRegs, Type *PackedTy,
                                MachineIRBuilder &MIRBuilder) const {
  assert(SrcRegs.size() > 1 && "Nothing to pack");

  const DataLayout &DL = MIRBuilder.getMF().getDataLayout();
  MachineRegisterInfo *MRI = MIRBuilder.getMRI();

  LLT PackedLLT = getLLTForType(*PackedTy, DL);

  SmallVector<LLT, 8> LLTs;
  SmallVector<uint64_t, 8> Offsets;
  computeValueLLTs(DL, *PackedTy, LLTs, &Offsets);
  assert(LLTs.size() == SrcRegs.size() && "Regs / types mismatch");

  Register Dst = MRI->createGenericVirtualRegister(PackedLLT);
  MIRBuilder.buildUndef(Dst);
  for (unsigned i = 0; i < SrcRegs.size(); ++i) {
    Register NewDst = MRI->createGenericVirtualRegister(PackedLLT);
    MIRBuilder.buildInsert(NewDst, Dst, SrcRegs[i], Offsets[i]);
    Dst = NewDst;
  }

  return Dst;
}

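// The inverse of packRegs: extract each component of SrcReg (which has type
// PackedTy) into the corresponding destination register.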
void CallLowering::unpackRegs(ArrayRef<Register> DstRegs, Register SrcReg,
                              Type *PackedTy,
                              MachineIRBuilder &MIRBuilder) const {
  assert(DstRegs.size() > 1 && "Nothing to unpack");

  const DataLayout &DL = MIRBuilder.getMF().getDataLayout();

  SmallVector<LLT, 8> LLTs;
  SmallVector<uint64_t, 8> Offsets;
  computeValueLLTs(DL, *PackedTy, LLTs, &Offsets);
  assert(LLTs.size() == DstRegs.size() && "Regs / types mismatch");

  for (unsigned i = 0; i < DstRegs.size(); ++i)
    MIRBuilder.buildExtract(DstRegs[i], SrcReg, Offsets[i]);
}

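// Convenience overload: build a CCState from the current function's calling
// convention and delegate to the main handleAssignments below.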
bool CallLowering::handleAssignments(MachineIRBuilder &MIRBuilder,
                                     SmallVectorImpl<ArgInfo> &Args,
                                     ValueHandler &Handler) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const Function &F = MF.getFunction();
  SmallVector<CCValAssign, 16> ArgLocs;
  CCState CCInfo(F.getCallingConv(), F.isVarArg(), MF, ArgLocs, F.getContext());
  return handleAssignments(CCInfo, ArgLocs, MIRBuilder, Args, Handler);
}

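// The main assignment routine runs in two phases: first compute a CCValAssign
// for every argument, splitting values the calling convention passes in
// multiple registers; then emit the copies, merges, or memory operations that
// realize each assignment.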
bool CallLowering::handleAssignments(CCState &CCInfo,
                                     SmallVectorImpl<CCValAssign> &ArgLocs,
                                     MachineIRBuilder &MIRBuilder,
                                     SmallVectorImpl<ArgInfo> &Args,
                                     ValueHandler &Handler) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const Function &F = MF.getFunction();
  const DataLayout &DL = F.getParent()->getDataLayout();

  unsigned NumArgs = Args.size();
  for (unsigned i = 0; i != NumArgs; ++i) {
    MVT CurVT = MVT::getVT(Args[i].Ty);
    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i],
                          Args[i].Flags[0], CCInfo)) {
      if (!CurVT.isValid())
        return false;
      MVT NewVT = TLI->getRegisterTypeForCallingConv(
          F.getContext(), F.getCallingConv(), EVT(CurVT));

      // If we need to split the type over multiple regs, check it's a scenario
      // we currently support.
      unsigned NumParts = TLI->getNumRegistersForCallingConv(
          F.getContext(), F.getCallingConv(), CurVT);
      if (NumParts > 1) {
        // For now only handle exact splits.
        if (NewVT.getSizeInBits() * NumParts != CurVT.getSizeInBits())
          return false;
      }

      // For incoming arguments (physregs to vregs), we could have values in
      // physregs (or memlocs) which we want to extract and copy to vregs.
      // During this, we might have to deal with the LLT being split across
      // multiple regs, so we have to record this information for later.
      //
      // If we have outgoing args, then we have the opposite case. We have a
      // vreg with an LLT which we want to assign to a physical location, and
      // we might have to record that the value has to be split later.
      if (Handler.isIncomingArgumentHandler()) {
        if (NumParts == 1) {
          // Try to use the register type if we couldn't assign the VT.
          if (Handler.assignArg(i, NewVT, NewVT, CCValAssign::Full, Args[i],
                                Args[i].Flags[0], CCInfo))
            return false;
        } else {
          // We're handling an incoming arg which is split over multiple regs.
          // E.g. passing an s128 on AArch64.
          ISD::ArgFlagsTy OrigFlags = Args[i].Flags[0];
          Args[i].OrigRegs.push_back(Args[i].Regs[0]);
          Args[i].Regs.clear();
          Args[i].Flags.clear();
          LLT NewLLT = getLLTForMVT(NewVT);
          // For each split register, create and assign a vreg that will store
          // the incoming component of the larger value. These will later be
          // merged to form the final vreg.
          for (unsigned Part = 0; Part < NumParts; ++Part) {
            Register Reg =
                MIRBuilder.getMRI()->createGenericVirtualRegister(NewLLT);
            ISD::ArgFlagsTy Flags = OrigFlags;
            if (Part == 0) {
              Flags.setSplit();
            } else {
              Flags.setOrigAlign(Align(1));
              if (Part == NumParts - 1)
                Flags.setSplitEnd();
            }
            Args[i].Regs.push_back(Reg);
            Args[i].Flags.push_back(Flags);
            if (Handler.assignArg(i + Part, NewVT, NewVT, CCValAssign::Full,
                                  Args[i], Args[i].Flags[Part], CCInfo)) {
              // Still couldn't assign this smaller part type for some reason.
              return false;
            }
          }
        }
      } else {
        // Handling an outgoing arg that might need to be split.
        if (NumParts < 2)
          return false; // Don't know how to deal with this type combination.

        // This type is passed via multiple registers in the calling convention.
        // We need to extract the individual parts.
        Register LargeReg = Args[i].Regs[0];
        LLT SmallTy = LLT::scalar(NewVT.getSizeInBits());
        auto Unmerge = MIRBuilder.buildUnmerge(SmallTy, LargeReg);
        assert(Unmerge->getNumOperands() == NumParts + 1);
        ISD::ArgFlagsTy OrigFlags = Args[i].Flags[0];
        // We're going to replace the regs and flags with the split ones.
        Args[i].Regs.clear();
        Args[i].Flags.clear();
        for (unsigned PartIdx = 0; PartIdx < NumParts; ++PartIdx) {
          ISD::ArgFlagsTy Flags = OrigFlags;
          if (PartIdx == 0) {
            Flags.setSplit();
          } else {
            Flags.setOrigAlign(Align(1));
            if (PartIdx == NumParts - 1)
              Flags.setSplitEnd();
          }
          Args[i].Regs.push_back(Unmerge.getReg(PartIdx));
          Args[i].Flags.push_back(Flags);
          if (Handler.assignArg(i + PartIdx, NewVT, NewVT, CCValAssign::Full,
                                Args[i], Args[i].Flags[PartIdx], CCInfo))
            return false;
        }
      }
    }
  }

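  // All locations are now assigned. Emit the copies, loads, and stores that
  // move each value between its vreg(s) and the assigned location.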
  for (unsigned i = 0, e = Args.size(), j = 0; i != e; ++i, ++j) {
    assert(j < ArgLocs.size() && "Skipped too many arg locs");

    CCValAssign &VA = ArgLocs[j];
    assert(VA.getValNo() == i && "Location doesn't correspond to current arg");

    if (VA.needsCustom()) {
      j += Handler.assignCustomValue(Args[i], makeArrayRef(ArgLocs).slice(j));
      continue;
    }

    // FIXME: Pack registers if we have more than one.
    Register ArgReg = Args[i].Regs[0];

    MVT OrigVT = MVT::getVT(Args[i].Ty);
    MVT VAVT = VA.getValVT();
    if (VA.isRegLoc()) {
      if (Handler.isIncomingArgumentHandler() && VAVT != OrigVT) {
        if (VAVT.getSizeInBits() < OrigVT.getSizeInBits()) {
          // Expected to be multiple regs for a single incoming arg.
          unsigned NumArgRegs = Args[i].Regs.size();
          if (NumArgRegs < 2)
            return false;

          assert((j + (NumArgRegs - 1)) < ArgLocs.size() &&
                 "Too many regs for number of args");
          for (unsigned Part = 0; Part < NumArgRegs; ++Part) {
            // There should be Regs.size() ArgLocs per argument.
            VA = ArgLocs[j + Part];
            Handler.assignValueToReg(Args[i].Regs[Part], VA.getLocReg(), VA);
          }
          j += NumArgRegs - 1;
          // Merge the split registers into the expected larger result vreg
          // of the original call.
          MIRBuilder.buildMerge(Args[i].OrigRegs[0], Args[i].Regs);
          continue;
        }
        const LLT VATy(VAVT);
        Register NewReg =
            MIRBuilder.getMRI()->createGenericVirtualRegister(VATy);
        Handler.assignValueToReg(NewReg, VA.getLocReg(), VA);
        // If it's a vector type, we either need to truncate the elements
        // or do an unmerge to get the lower block of elements.
        if (VATy.isVector() &&
            VATy.getNumElements() > OrigVT.getVectorNumElements()) {
          const LLT OrigTy(OrigVT);
          // Just handle the case where the VA type is 2 * original type.
          if (VATy.getNumElements() != OrigVT.getVectorNumElements() * 2) {
            LLVM_DEBUG(dbgs()
                       << "Incoming promoted vector arg has too many elts");
            return false;
          }
          auto Unmerge = MIRBuilder.buildUnmerge({OrigTy, OrigTy}, {NewReg});
          MIRBuilder.buildCopy(ArgReg, Unmerge.getReg(0));
        } else {
          MIRBuilder.buildTrunc(ArgReg, {NewReg});
        }
      } else if (!Handler.isIncomingArgumentHandler()) {
        assert((j + (Args[i].Regs.size() - 1)) < ArgLocs.size() &&
               "Too many regs for number of args");
        // This is an outgoing argument that might have been split.
        for (unsigned Part = 0; Part < Args[i].Regs.size(); ++Part) {
          // There should be Regs.size() ArgLocs per argument.
          VA = ArgLocs[j + Part];
          Handler.assignValueToReg(Args[i].Regs[Part], VA.getLocReg(), VA);
        }
        j += Args[i].Regs.size() - 1;
      } else {
        Handler.assignValueToReg(ArgReg, VA.getLocReg(), VA);
      }
    } else if (VA.isMemLoc()) {
      // Don't currently support loading/storing a type that needs to be split
      // to the stack. Should be easy, just not implemented yet.
      if (Args[i].Regs.size() > 1) {
        LLVM_DEBUG(
            dbgs()
            << "Load/store a split arg to/from the stack not implemented yet");
        return false;
      }
      MVT VT = MVT::getVT(Args[i].Ty);
      unsigned Size = VT == MVT::iPTR ? DL.getPointerSize()
                                      : alignTo(VT.getSizeInBits(), 8) / 8;
      unsigned Offset = VA.getLocMemOffset();
      MachinePointerInfo MPO;
      Register StackAddr = Handler.getStackAddress(Size, Offset, MPO);
      Handler.assignValueToAddress(ArgReg, StackAddr, Size, MPO, VA);
    } else {
      // FIXME: Support byvals and other weirdness
      return false;
    }
  }
  return true;
}

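// Run the appropriate CCAssignFn over every argument and report whether each
// one could be given a location under this calling convention.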
bool CallLowering::analyzeArgInfo(CCState &CCState,
                                  SmallVectorImpl<ArgInfo> &Args,
                                  CCAssignFn &AssignFnFixed,
                                  CCAssignFn &AssignFnVarArg) const {
  for (unsigned i = 0, e = Args.size(); i < e; ++i) {
    MVT VT = MVT::getVT(Args[i].Ty);
    CCAssignFn &Fn = Args[i].IsFixed ? AssignFnFixed : AssignFnVarArg;
    if (Fn(i, VT, VT, CCValAssign::Full, Args[i].Flags[0], CCState)) {
      // Bail out on anything we can't handle.
      LLVM_DEBUG(dbgs() << "Cannot analyze " << EVT(VT).getEVTString()
                        << " (arg number = " << i << ")\n");
      return false;
    }
  }
  return true;
}

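// Check whether the callee and caller would assign these values to identical
// locations. Used when deciding whether a tail call can be lowered safely.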
bool CallLowering::resultsCompatible(CallLoweringInfo &Info,
                                     MachineFunction &MF,
                                     SmallVectorImpl<ArgInfo> &InArgs,
                                     CCAssignFn &CalleeAssignFnFixed,
                                     CCAssignFn &CalleeAssignFnVarArg,
                                     CCAssignFn &CallerAssignFnFixed,
                                     CCAssignFn &CallerAssignFnVarArg) const {
  const Function &F = MF.getFunction();
  CallingConv::ID CalleeCC = Info.CallConv;
  CallingConv::ID CallerCC = F.getCallingConv();

  if (CallerCC == CalleeCC)
    return true;

  SmallVector<CCValAssign, 16> ArgLocs1;
  CCState CCInfo1(CalleeCC, false, MF, ArgLocs1, F.getContext());
  if (!analyzeArgInfo(CCInfo1, InArgs, CalleeAssignFnFixed,
                      CalleeAssignFnVarArg))
    return false;

  SmallVector<CCValAssign, 16> ArgLocs2;
  CCState CCInfo2(CallerCC, false, MF, ArgLocs2, F.getContext());
  if (!analyzeArgInfo(CCInfo2, InArgs, CallerAssignFnFixed,
                      CallerAssignFnVarArg))
    return false;

  // We need the argument locations to match up exactly. If there's more in
  // one than the other, then we are done.
  if (ArgLocs1.size() != ArgLocs2.size())
    return false;

  // Make sure that each location is passed in exactly the same way.
  for (unsigned i = 0, e = ArgLocs1.size(); i < e; ++i) {
    const CCValAssign &Loc1 = ArgLocs1[i];
    const CCValAssign &Loc2 = ArgLocs2[i];

    // We need both of them to be the same. So if one is a register and one
    // isn't, we're done.
    if (Loc1.isRegLoc() != Loc2.isRegLoc())
      return false;

    if (Loc1.isRegLoc()) {
      // If they don't have the same register location, we're done.
      if (Loc1.getLocReg() != Loc2.getLocReg())
        return false;

      // They matched, so we can move to the next ArgLoc.
      continue;
    }

    // Loc1 wasn't a RegLoc, so they both must be MemLocs. Check if they match.
    if (Loc1.getLocMemOffset() != Loc2.getLocMemOffset())
      return false;
  }

  return true;
}

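// Extend ValReg to the location type recorded in VA, emitting whichever of
// G_ANYEXT, G_SEXT, or G_ZEXT the assignment's LocInfo calls for.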
Register CallLowering::ValueHandler::extendRegister(Register ValReg,
                                                    CCValAssign &VA) {
  LLT LocTy{VA.getLocVT()};
  if (LocTy.getSizeInBits() == MRI.getType(ValReg).getSizeInBits())
    return ValReg;
  switch (VA.getLocInfo()) {
  default: break;
  case CCValAssign::Full:
  case CCValAssign::BCvt:
    // FIXME: bitconverting between vector types may or may not be a
    // nop in big-endian situations.
    return ValReg;
  case CCValAssign::AExt: {
    auto MIB = MIRBuilder.buildAnyExt(LocTy, ValReg);
    return MIB.getReg(0);
  }
  case CCValAssign::SExt: {
    Register NewReg = MRI.createGenericVirtualRegister(LocTy);
    MIRBuilder.buildSExt(NewReg, ValReg);
    return NewReg;
  }
  case CCValAssign::ZExt: {
    Register NewReg = MRI.createGenericVirtualRegister(LocTy);
    MIRBuilder.buildZExt(NewReg, ValReg);
    return NewReg;
  }
  }
  llvm_unreachable("unable to extend register");
}

void CallLowering::ValueHandler::anchor() {}