xref: /llvm-project/llvm/lib/IR/AbstractCallSite.cpp (revision cd28a4736ab299f81a6bc74e8f22fb6d2b9375ed)
1 //===-- AbstractCallSite.cpp - Implementation of abstract call sites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements abstract call sites which unify the interface for
10 // direct, indirect, and callback call sites.
11 //
12 // For more information see:
13 // https://llvm.org/devmtg/2018-10/talk-abstracts.html#talk20
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/StringSwitch.h"
19 #include "llvm/IR/CallSite.h"
20 #include "llvm/Support/Debug.h"
21 
22 using namespace llvm;
23 
24 #define DEBUG_TYPE "abstract-call-sites"
25 
26 STATISTIC(NumCallbackCallSites, "Number of callback call sites created");
27 STATISTIC(NumDirectAbstractCallSites,
28           "Number of direct abstract call sites created");
29 STATISTIC(NumInvalidAbstractCallSitesUnknownUse,
30           "Number of invalid abstract call sites created (unknown use)");
31 STATISTIC(NumInvalidAbstractCallSitesUnknownCallee,
32           "Number of invalid abstract call sites created (unknown callee)");
33 STATISTIC(NumInvalidAbstractCallSitesNoCallback,
34           "Number of invalid abstract call sites created (no callback)");
35 
36 void AbstractCallSite::getCallbackUses(
37     const CallBase &CB, SmallVectorImpl<const Use *> &CallbackUses) {
38   const Function *Callee = CB.getCalledFunction();
39   if (!Callee)
40     return;
41 
42   MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
43   if (!CallbackMD)
44     return;
45 
46   for (const MDOperand &Op : CallbackMD->operands()) {
47     MDNode *OpMD = cast<MDNode>(Op.get());
48     auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
49     uint64_t CBCalleeIdx =
50         cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
51     if (CBCalleeIdx < CB.arg_size())
52       CallbackUses.push_back(CB.arg_begin() + CBCalleeIdx);
53   }
54 }
55 
56 /// Create an abstract call site from a use.
57 AbstractCallSite::AbstractCallSite(const Use *U)
58     : CB(dyn_cast<CallBase>(U->getUser())) {
59 
60   // First handle unknown users.
61   if (!CB) {
62 
63     // If the use is actually in a constant cast expression which itself
64     // has only one use, we look through the constant cast expression.
65     // This happens by updating the use @p U to the use of the constant
66     // cast expression and afterwards re-initializing CB accordingly.
67     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U->getUser()))
68       if (CE->getNumUses() == 1 && CE->isCast()) {
69         U = &*CE->use_begin();
70         CB = dyn_cast<CallBase>(U->getUser());
71       }
72 
73     if (!CB) {
74       NumInvalidAbstractCallSitesUnknownUse++;
75       return;
76     }
77   }
78 
79   // Then handle direct or indirect calls. Thus, if U is the callee of the
80   // call site CB it is not a callback and we are done.
81   if (CB->isCallee(U)) {
82     NumDirectAbstractCallSites++;
83     return;
84   }
85 
86   // If we cannot identify the broker function we cannot create a callback and
87   // invalidate the abstract call site.
88   Function *Callee = CB->getCalledFunction();
89   if (!Callee) {
90     NumInvalidAbstractCallSitesUnknownCallee++;
91     CB = nullptr;
92     return;
93   }
94 
95   MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
96   if (!CallbackMD) {
97     NumInvalidAbstractCallSitesNoCallback++;
98     CB = nullptr;
99     return;
100   }
101 
102   unsigned UseIdx = CB->getArgOperandNo(U);
103   MDNode *CallbackEncMD = nullptr;
104   for (const MDOperand &Op : CallbackMD->operands()) {
105     MDNode *OpMD = cast<MDNode>(Op.get());
106     auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
107     uint64_t CBCalleeIdx =
108         cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
109     if (CBCalleeIdx != UseIdx)
110       continue;
111     CallbackEncMD = OpMD;
112     break;
113   }
114 
115   if (!CallbackEncMD) {
116     NumInvalidAbstractCallSitesNoCallback++;
117     CB = nullptr;
118     return;
119   }
120 
121   NumCallbackCallSites++;
122 
123   assert(CallbackEncMD->getNumOperands() >= 2 && "Incomplete !callback metadata");
124 
125   unsigned NumCallOperands = CB->getNumArgOperands();
126   // Skip the var-arg flag at the end when reading the metadata.
127   for (unsigned u = 0, e = CallbackEncMD->getNumOperands() - 1; u < e; u++) {
128     Metadata *OpAsM = CallbackEncMD->getOperand(u).get();
129     auto *OpAsCM = cast<ConstantAsMetadata>(OpAsM);
130     assert(OpAsCM->getType()->isIntegerTy(64) &&
131            "Malformed !callback metadata");
132 
133     int64_t Idx = cast<ConstantInt>(OpAsCM->getValue())->getSExtValue();
134     assert(-1 <= Idx && Idx <= NumCallOperands &&
135            "Out-of-bounds !callback metadata index");
136 
137     CI.ParameterEncoding.push_back(Idx);
138   }
139 
140   if (!Callee->isVarArg())
141     return;
142 
143   Metadata *VarArgFlagAsM =
144       CallbackEncMD->getOperand(CallbackEncMD->getNumOperands() - 1).get();
145   auto *VarArgFlagAsCM = cast<ConstantAsMetadata>(VarArgFlagAsM);
146   assert(VarArgFlagAsCM->getType()->isIntegerTy(1) &&
147          "Malformed !callback metadata var-arg flag");
148 
149   if (VarArgFlagAsCM->getValue()->isNullValue())
150     return;
151 
152   // Add all variadic arguments at the end.
153   for (unsigned u = Callee->arg_size(); u < NumCallOperands; u++)
154     CI.ParameterEncoding.push_back(u);
155 }
156