xref: /llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyAddMissingPrototypes.cpp (revision 43570a2841e2a8f1efd00503beee751cc1e72513)
1 //===-- WebAssemblyAddMissingPrototypes.cpp - Fix prototypeless functions -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Add prototypes to prototype-less functions.
11 ///
12 /// WebAssembly has strict function prototype checking, so we need function
13 /// declarations to match the call sites.  Clang treats prototype-less functions
14 /// as varargs (foo(...)), which happens to work on existing platforms but
15 /// doesn't under WebAssembly.  This pass finds all the call sites of each
16 /// prototype-less function, warns if their signatures disagree, and then sets
17 /// the declaration's signature from the first call site found.
18 ///
19 //===----------------------------------------------------------------------===//
20 
21 #include "WebAssembly.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IR/Operator.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Transforms/Utils/Local.h"
28 #include "llvm/Transforms/Utils/ModuleUtils.h"
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "wasm-add-missing-prototypes"
32 
33 namespace {
34 class WebAssemblyAddMissingPrototypes final : public ModulePass {
35   StringRef getPassName() const override {
36     return "Add prototypes to prototypes-less functions";
37   }
38 
39   void getAnalysisUsage(AnalysisUsage &AU) const override {
40     AU.setPreservesCFG();
41     ModulePass::getAnalysisUsage(AU);
42   }
43 
44   bool runOnModule(Module &M) override;
45 
46 public:
47   static char ID;
48   WebAssemblyAddMissingPrototypes() : ModulePass(ID) {}
49 };
50 } // End anonymous namespace
51 
52 char WebAssemblyAddMissingPrototypes::ID = 0;
53 INITIALIZE_PASS(WebAssemblyAddMissingPrototypes, DEBUG_TYPE,
54                 "Add prototypes to prototypes-less functions", false, false)
55 
56 ModulePass *llvm::createWebAssemblyAddMissingPrototypes() {
57   return new WebAssemblyAddMissingPrototypes();
58 }
59 
60 bool WebAssemblyAddMissingPrototypes::runOnModule(Module &M) {
61   LLVM_DEBUG(dbgs() << "********** Add Missing Prototypes **********\n");
62 
63   std::vector<std::pair<Function *, Function *>> Replacements;
64 
65   // Find all the prototype-less function declarations
66   for (Function &F : M) {
67     if (!F.isDeclaration() || !F.hasFnAttribute("no-prototype"))
68       continue;
69 
70     LLVM_DEBUG(dbgs() << "Found no-prototype function: " << F.getName()
71                       << "\n");
72 
73     // When clang emits prototype-less C functions it uses (...), i.e. varargs
74     // function that take no arguments (have no sentinel).  When we see a
75     // no-prototype attribute we expect the function have these properties.
76     if (!F.isVarArg())
77       report_fatal_error(
78           "Functions with 'no-prototype' attribute must take varargs: " +
79           F.getName());
80     unsigned NumParams = F.getFunctionType()->getNumParams();
81     if (NumParams != 0) {
82       if (!(NumParams == 1 && F.arg_begin()->hasStructRetAttr()))
83         report_fatal_error("Functions with 'no-prototype' attribute should "
84                            "not have params: " +
85                            F.getName());
86     }
87 
88     // Find calls of this function, looking through bitcasts.
89     SmallVector<CallBase *> Calls;
90     SmallVector<Value *> Worklist;
91     Worklist.push_back(&F);
92     while (!Worklist.empty()) {
93       Value *V = Worklist.pop_back_val();
94       for (User *U : V->users()) {
95         if (auto *BC = dyn_cast<BitCastOperator>(U))
96           Worklist.push_back(BC);
97         else if (auto *CB = dyn_cast<CallBase>(U))
98           if (CB->getCalledOperand() == V)
99             Calls.push_back(CB);
100       }
101     }
102 
103     // Create a function prototype based on the first call site that we find.
104     FunctionType *NewType = nullptr;
105     for (CallBase *CB : Calls) {
106       LLVM_DEBUG(dbgs() << "prototype-less call of " << F.getName() << ":\n");
107       LLVM_DEBUG(dbgs() << *CB << "\n");
108       FunctionType *DestType = CB->getFunctionType();
109       if (!NewType) {
110         // Create a new function with the correct type
111         NewType = DestType;
112         LLVM_DEBUG(dbgs() << "found function type: " << *NewType << "\n");
113       } else if (NewType != DestType) {
114         errs() << "warning: prototype-less function used with "
115                   "conflicting signatures: "
116                << F.getName() << "\n";
117         LLVM_DEBUG(dbgs() << "  " << *DestType << "\n");
118         LLVM_DEBUG(dbgs() << "  " << *NewType << "\n");
119       }
120     }
121 
122     if (!NewType) {
123       LLVM_DEBUG(
124           dbgs() << "could not derive a function prototype from usage: " +
125                         F.getName() + "\n");
126       // We could not derive a type for this function.  In this case strip
127       // the isVarArg and make it a simple zero-arg function.  This has more
128       // chance of being correct.  The current signature of (...) is illegal in
129       // C since it doesn't have any arguments before the "...", we this at
130       // least makes it possible for this symbol to be resolved by the linker.
131       NewType = FunctionType::get(F.getFunctionType()->getReturnType(), false);
132     }
133 
134     Function *NewF =
135         Function::Create(NewType, F.getLinkage(), F.getName() + ".fixed_sig");
136     NewF->setAttributes(F.getAttributes());
137     NewF->removeFnAttr("no-prototype");
138     NewF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat;
139     Replacements.emplace_back(&F, NewF);
140   }
141 
142   for (auto &Pair : Replacements) {
143     Function *OldF = Pair.first;
144     Function *NewF = Pair.second;
145     std::string Name = std::string(OldF->getName());
146     M.getFunctionList().push_back(NewF);
147     OldF->replaceAllUsesWith(
148         ConstantExpr::getPointerBitCastOrAddrSpaceCast(NewF, OldF->getType()));
149     OldF->eraseFromParent();
150     NewF->setName(Name);
151   }
152 
153   return !Replacements.empty();
154 }
155