xref: /openbsd-src/gnu/llvm/clang/lib/Analysis/UnsafeBufferUsage.cpp (revision 12c855180aad702bbcca06e0398d774beeafb155)
1*12c85518Srobert //===- UnsafeBufferUsage.cpp - Replace pointers with modern C++ -----------===//
2*12c85518Srobert //
3*12c85518Srobert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*12c85518Srobert // See https://llvm.org/LICENSE.txt for license information.
5*12c85518Srobert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*12c85518Srobert //
7*12c85518Srobert //===----------------------------------------------------------------------===//
8*12c85518Srobert 
9*12c85518Srobert #include "clang/Analysis/Analyses/UnsafeBufferUsage.h"
10*12c85518Srobert #include "clang/AST/RecursiveASTVisitor.h"
11*12c85518Srobert #include "clang/ASTMatchers/ASTMatchFinder.h"
12*12c85518Srobert #include "llvm/ADT/SmallVector.h"
13*12c85518Srobert #include <memory>
14*12c85518Srobert #include <optional>
15*12c85518Srobert 
16*12c85518Srobert using namespace llvm;
17*12c85518Srobert using namespace clang;
18*12c85518Srobert using namespace ast_matchers;
19*12c85518Srobert 
20*12c85518Srobert namespace clang::ast_matchers {
21*12c85518Srobert // A `RecursiveASTVisitor` that traverses all descendants of a given node "n"
22*12c85518Srobert // except for those belonging to a different callable of "n".
23*12c85518Srobert class MatchDescendantVisitor
24*12c85518Srobert     : public RecursiveASTVisitor<MatchDescendantVisitor> {
25*12c85518Srobert public:
26*12c85518Srobert   typedef RecursiveASTVisitor<MatchDescendantVisitor> VisitorBase;
27*12c85518Srobert 
28*12c85518Srobert   // Creates an AST visitor that matches `Matcher` on all
29*12c85518Srobert   // descendants of a given node "n" except for the ones
30*12c85518Srobert   // belonging to a different callable of "n".
MatchDescendantVisitor(const internal::DynTypedMatcher * Matcher,internal::ASTMatchFinder * Finder,internal::BoundNodesTreeBuilder * Builder,internal::ASTMatchFinder::BindKind Bind)31*12c85518Srobert   MatchDescendantVisitor(const internal::DynTypedMatcher *Matcher,
32*12c85518Srobert                          internal::ASTMatchFinder *Finder,
33*12c85518Srobert                          internal::BoundNodesTreeBuilder *Builder,
34*12c85518Srobert                          internal::ASTMatchFinder::BindKind Bind)
35*12c85518Srobert       : Matcher(Matcher), Finder(Finder), Builder(Builder), Bind(Bind),
36*12c85518Srobert         Matches(false) {}
37*12c85518Srobert 
38*12c85518Srobert   // Returns true if a match is found in a subtree of `DynNode`, which belongs
39*12c85518Srobert   // to the same callable of `DynNode`.
findMatch(const DynTypedNode & DynNode)40*12c85518Srobert   bool findMatch(const DynTypedNode &DynNode) {
41*12c85518Srobert     Matches = false;
42*12c85518Srobert     if (const Stmt *StmtNode = DynNode.get<Stmt>()) {
43*12c85518Srobert       TraverseStmt(const_cast<Stmt *>(StmtNode));
44*12c85518Srobert       *Builder = ResultBindings;
45*12c85518Srobert       return Matches;
46*12c85518Srobert     }
47*12c85518Srobert     return false;
48*12c85518Srobert   }
49*12c85518Srobert 
50*12c85518Srobert   // The following are overriding methods from the base visitor class.
51*12c85518Srobert   // They are public only to allow CRTP to work. They are *not *part
52*12c85518Srobert   // of the public API of this class.
53*12c85518Srobert 
54*12c85518Srobert   // For the matchers so far used in safe buffers, we only need to match
55*12c85518Srobert   // `Stmt`s.  To override more as needed.
56*12c85518Srobert 
TraverseDecl(Decl * Node)57*12c85518Srobert   bool TraverseDecl(Decl *Node) {
58*12c85518Srobert     if (!Node)
59*12c85518Srobert       return true;
60*12c85518Srobert     if (!match(*Node))
61*12c85518Srobert       return false;
62*12c85518Srobert     // To skip callables:
63*12c85518Srobert     if (isa<FunctionDecl, BlockDecl, ObjCMethodDecl>(Node))
64*12c85518Srobert       return true;
65*12c85518Srobert     // Traverse descendants
66*12c85518Srobert     return VisitorBase::TraverseDecl(Node);
67*12c85518Srobert   }
68*12c85518Srobert 
TraverseStmt(Stmt * Node,DataRecursionQueue * Queue=nullptr)69*12c85518Srobert   bool TraverseStmt(Stmt *Node, DataRecursionQueue *Queue = nullptr) {
70*12c85518Srobert     if (!Node)
71*12c85518Srobert       return true;
72*12c85518Srobert     if (!match(*Node))
73*12c85518Srobert       return false;
74*12c85518Srobert     // To skip callables:
75*12c85518Srobert     if (isa<LambdaExpr>(Node))
76*12c85518Srobert       return true;
77*12c85518Srobert     return VisitorBase::TraverseStmt(Node);
78*12c85518Srobert   }
79*12c85518Srobert 
shouldVisitTemplateInstantiations() const80*12c85518Srobert   bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const81*12c85518Srobert   bool shouldVisitImplicitCode() const {
82*12c85518Srobert     // TODO: let's ignore implicit code for now
83*12c85518Srobert     return false;
84*12c85518Srobert   }
85*12c85518Srobert 
86*12c85518Srobert private:
87*12c85518Srobert   // Sets 'Matched' to true if 'Matcher' matches 'Node'
88*12c85518Srobert   //
89*12c85518Srobert   // Returns 'true' if traversal should continue after this function
90*12c85518Srobert   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
match(const T & Node)91*12c85518Srobert   template <typename T> bool match(const T &Node) {
92*12c85518Srobert     internal::BoundNodesTreeBuilder RecursiveBuilder(*Builder);
93*12c85518Srobert 
94*12c85518Srobert     if (Matcher->matches(DynTypedNode::create(Node), Finder,
95*12c85518Srobert                          &RecursiveBuilder)) {
96*12c85518Srobert       ResultBindings.addMatch(RecursiveBuilder);
97*12c85518Srobert       Matches = true;
98*12c85518Srobert       if (Bind != internal::ASTMatchFinder::BK_All)
99*12c85518Srobert         return false; // Abort as soon as a match is found.
100*12c85518Srobert     }
101*12c85518Srobert     return true;
102*12c85518Srobert   }
103*12c85518Srobert 
104*12c85518Srobert   const internal::DynTypedMatcher *const Matcher;
105*12c85518Srobert   internal::ASTMatchFinder *const Finder;
106*12c85518Srobert   internal::BoundNodesTreeBuilder *const Builder;
107*12c85518Srobert   internal::BoundNodesTreeBuilder ResultBindings;
108*12c85518Srobert   const internal::ASTMatchFinder::BindKind Bind;
109*12c85518Srobert   bool Matches;
110*12c85518Srobert };
111*12c85518Srobert 
// Matches a statement if `innerMatcher` matches any of its descendants that
// belong to the same callable; every such match is bound (`BK_All`).
AST_MATCHER_P(Stmt, forEveryDescendant, internal::Matcher<Stmt>, innerMatcher) {
  const DynTypedMatcher DTM = static_cast<DynTypedMatcher>(innerMatcher);
  MatchDescendantVisitor DescendantVisitor(&DTM, Finder, Builder,
                                           ASTMatchFinder::BK_All);
  const DynTypedNode Root = DynTypedNode::create(Node);
  return DescendantVisitor.findMatch(Root);
}
118*12c85518Srobert } // namespace clang::ast_matchers
119*12c85518Srobert 
120*12c85518Srobert namespace {
121*12c85518Srobert // Because the analysis revolves around variables and their types, we'll need to
122*12c85518Srobert // track uses of variables (aka DeclRefExprs).
123*12c85518Srobert using DeclUseList = SmallVector<const DeclRefExpr *, 1>;
124*12c85518Srobert 
125*12c85518Srobert // Convenience typedef.
126*12c85518Srobert using FixItList = SmallVector<FixItHint, 4>;
127*12c85518Srobert 
128*12c85518Srobert // Defined below.
129*12c85518Srobert class Strategy;
130*12c85518Srobert } // namespace
131*12c85518Srobert 
132*12c85518Srobert // Because we're dealing with raw pointers, let's define what we mean by that.
hasPointerType()133*12c85518Srobert static auto hasPointerType() {
134*12c85518Srobert     return hasType(hasCanonicalType(pointerType()));
135*12c85518Srobert }
136*12c85518Srobert 
hasArrayType()137*12c85518Srobert static auto hasArrayType() {
138*12c85518Srobert     return hasType(hasCanonicalType(arrayType()));
139*12c85518Srobert }
140*12c85518Srobert 
141*12c85518Srobert namespace {
142*12c85518Srobert /// Gadget is an individual operation in the code that may be of interest to
143*12c85518Srobert /// this analysis. Each (non-abstract) subclass corresponds to a specific
144*12c85518Srobert /// rigid AST structure that constitutes an operation on a pointer-type object.
145*12c85518Srobert /// Discovery of a gadget in the code corresponds to claiming that we understand
146*12c85518Srobert /// what this part of code is doing well enough to potentially improve it.
147*12c85518Srobert /// Gadgets can be warning (immediately deserving a warning) or fixable (not
148*12c85518Srobert /// always deserving a warning per se, but requires our attention to identify
149*12c85518Srobert /// it warrants a fixit).
150*12c85518Srobert class Gadget {
151*12c85518Srobert public:
152*12c85518Srobert   enum class Kind {
153*12c85518Srobert #define GADGET(x) x,
154*12c85518Srobert #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
155*12c85518Srobert   };
156*12c85518Srobert 
157*12c85518Srobert   /// Common type of ASTMatchers used for discovering gadgets.
158*12c85518Srobert   /// Useful for implementing the static matcher() methods
159*12c85518Srobert   /// that are expected from all non-abstract subclasses.
160*12c85518Srobert   using Matcher = decltype(stmt());
161*12c85518Srobert 
Gadget(Kind K)162*12c85518Srobert   Gadget(Kind K) : K(K) {}
163*12c85518Srobert 
getKind() const164*12c85518Srobert   Kind getKind() const { return K; }
165*12c85518Srobert 
166*12c85518Srobert   virtual bool isWarningGadget() const = 0;
167*12c85518Srobert   virtual const Stmt *getBaseStmt() const = 0;
168*12c85518Srobert 
169*12c85518Srobert   /// Returns the list of pointer-type variables on which this gadget performs
170*12c85518Srobert   /// its operation. Typically, there's only one variable. This isn't a list
171*12c85518Srobert   /// of all DeclRefExprs in the gadget's AST!
172*12c85518Srobert   virtual DeclUseList getClaimedVarUseSites() const = 0;
173*12c85518Srobert 
174*12c85518Srobert   virtual ~Gadget() = default;
175*12c85518Srobert 
176*12c85518Srobert private:
177*12c85518Srobert   Kind K;
178*12c85518Srobert };
179*12c85518Srobert 
180*12c85518Srobert 
181*12c85518Srobert /// Warning gadgets correspond to unsafe code patterns that warrants
182*12c85518Srobert /// an immediate warning.
183*12c85518Srobert class WarningGadget : public Gadget {
184*12c85518Srobert public:
WarningGadget(Kind K)185*12c85518Srobert   WarningGadget(Kind K) : Gadget(K) {}
186*12c85518Srobert 
classof(const Gadget * G)187*12c85518Srobert   static bool classof(const Gadget *G) { return G->isWarningGadget(); }
isWarningGadget() const188*12c85518Srobert   bool isWarningGadget() const final { return true; }
189*12c85518Srobert };
190*12c85518Srobert 
191*12c85518Srobert /// Fixable gadgets correspond to code patterns that aren't always unsafe but need to be
192*12c85518Srobert /// properly recognized in order to emit fixes. For example, if a raw pointer-type
193*12c85518Srobert /// variable is replaced by a safe C++ container, every use of such variable must be
194*12c85518Srobert /// carefully considered and possibly updated.
195*12c85518Srobert class FixableGadget : public Gadget {
196*12c85518Srobert public:
FixableGadget(Kind K)197*12c85518Srobert   FixableGadget(Kind K) : Gadget(K) {}
198*12c85518Srobert 
classof(const Gadget * G)199*12c85518Srobert   static bool classof(const Gadget *G) { return !G->isWarningGadget(); }
isWarningGadget() const200*12c85518Srobert   bool isWarningGadget() const final { return false; }
201*12c85518Srobert 
202*12c85518Srobert   /// Returns a fixit that would fix the current gadget according to
203*12c85518Srobert   /// the current strategy. Returns None if the fix cannot be produced;
204*12c85518Srobert   /// returns an empty list if no fixes are necessary.
getFixits(const Strategy &) const205*12c85518Srobert   virtual std::optional<FixItList> getFixits(const Strategy &) const {
206*12c85518Srobert     return std::nullopt;
207*12c85518Srobert   }
208*12c85518Srobert };
209*12c85518Srobert 
210*12c85518Srobert using FixableGadgetList = std::vector<std::unique_ptr<FixableGadget>>;
211*12c85518Srobert using WarningGadgetList = std::vector<std::unique_ptr<WarningGadget>>;
212*12c85518Srobert 
213*12c85518Srobert /// An increment of a pointer-type value is unsafe as it may run the pointer
214*12c85518Srobert /// out of bounds.
215*12c85518Srobert class IncrementGadget : public WarningGadget {
216*12c85518Srobert   static constexpr const char *const OpTag = "op";
217*12c85518Srobert   const UnaryOperator *Op;
218*12c85518Srobert 
219*12c85518Srobert public:
IncrementGadget(const MatchFinder::MatchResult & Result)220*12c85518Srobert   IncrementGadget(const MatchFinder::MatchResult &Result)
221*12c85518Srobert       : WarningGadget(Kind::Increment),
222*12c85518Srobert         Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {}
223*12c85518Srobert 
classof(const Gadget * G)224*12c85518Srobert   static bool classof(const Gadget *G) {
225*12c85518Srobert     return G->getKind() == Kind::Increment;
226*12c85518Srobert   }
227*12c85518Srobert 
matcher()228*12c85518Srobert   static Matcher matcher() {
229*12c85518Srobert     return stmt(unaryOperator(
230*12c85518Srobert       hasOperatorName("++"),
231*12c85518Srobert       hasUnaryOperand(ignoringParenImpCasts(hasPointerType()))
232*12c85518Srobert     ).bind(OpTag));
233*12c85518Srobert   }
234*12c85518Srobert 
getBaseStmt() const235*12c85518Srobert   const UnaryOperator *getBaseStmt() const override { return Op; }
236*12c85518Srobert 
getClaimedVarUseSites() const237*12c85518Srobert   DeclUseList getClaimedVarUseSites() const override {
238*12c85518Srobert     SmallVector<const DeclRefExpr *, 2> Uses;
239*12c85518Srobert     if (const auto *DRE =
240*12c85518Srobert             dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
241*12c85518Srobert       Uses.push_back(DRE);
242*12c85518Srobert     }
243*12c85518Srobert 
244*12c85518Srobert     return std::move(Uses);
245*12c85518Srobert   }
246*12c85518Srobert };
247*12c85518Srobert 
248*12c85518Srobert /// A decrement of a pointer-type value is unsafe as it may run the pointer
249*12c85518Srobert /// out of bounds.
250*12c85518Srobert class DecrementGadget : public WarningGadget {
251*12c85518Srobert   static constexpr const char *const OpTag = "op";
252*12c85518Srobert   const UnaryOperator *Op;
253*12c85518Srobert 
254*12c85518Srobert public:
DecrementGadget(const MatchFinder::MatchResult & Result)255*12c85518Srobert   DecrementGadget(const MatchFinder::MatchResult &Result)
256*12c85518Srobert       : WarningGadget(Kind::Decrement),
257*12c85518Srobert         Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {}
258*12c85518Srobert 
classof(const Gadget * G)259*12c85518Srobert   static bool classof(const Gadget *G) {
260*12c85518Srobert     return G->getKind() == Kind::Decrement;
261*12c85518Srobert   }
262*12c85518Srobert 
matcher()263*12c85518Srobert   static Matcher matcher() {
264*12c85518Srobert     return stmt(unaryOperator(
265*12c85518Srobert       hasOperatorName("--"),
266*12c85518Srobert       hasUnaryOperand(ignoringParenImpCasts(hasPointerType()))
267*12c85518Srobert     ).bind(OpTag));
268*12c85518Srobert   }
269*12c85518Srobert 
getBaseStmt() const270*12c85518Srobert   const UnaryOperator *getBaseStmt() const override { return Op; }
271*12c85518Srobert 
getClaimedVarUseSites() const272*12c85518Srobert   DeclUseList getClaimedVarUseSites() const override {
273*12c85518Srobert     if (const auto *DRE =
274*12c85518Srobert             dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
275*12c85518Srobert       return {DRE};
276*12c85518Srobert     }
277*12c85518Srobert 
278*12c85518Srobert     return {};
279*12c85518Srobert   }
280*12c85518Srobert };
281*12c85518Srobert 
282*12c85518Srobert /// Array subscript expressions on raw pointers as if they're arrays. Unsafe as
283*12c85518Srobert /// it doesn't have any bounds checks for the array.
284*12c85518Srobert class ArraySubscriptGadget : public WarningGadget {
285*12c85518Srobert   static constexpr const char *const ArraySubscrTag = "arraySubscr";
286*12c85518Srobert   const ArraySubscriptExpr *ASE;
287*12c85518Srobert 
288*12c85518Srobert public:
ArraySubscriptGadget(const MatchFinder::MatchResult & Result)289*12c85518Srobert   ArraySubscriptGadget(const MatchFinder::MatchResult &Result)
290*12c85518Srobert       : WarningGadget(Kind::ArraySubscript),
291*12c85518Srobert         ASE(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ArraySubscrTag)) {}
292*12c85518Srobert 
classof(const Gadget * G)293*12c85518Srobert   static bool classof(const Gadget *G) {
294*12c85518Srobert     return G->getKind() == Kind::ArraySubscript;
295*12c85518Srobert   }
296*12c85518Srobert 
matcher()297*12c85518Srobert   static Matcher matcher() {
298*12c85518Srobert     // FIXME: What if the index is integer literal 0? Should this be
299*12c85518Srobert     // a safe gadget in this case?
300*12c85518Srobert       // clang-format off
301*12c85518Srobert       return stmt(arraySubscriptExpr(
302*12c85518Srobert             hasBase(ignoringParenImpCasts(
303*12c85518Srobert               anyOf(hasPointerType(), hasArrayType()))),
304*12c85518Srobert             unless(hasIndex(integerLiteral(equals(0)))))
305*12c85518Srobert             .bind(ArraySubscrTag));
306*12c85518Srobert       // clang-format on
307*12c85518Srobert   }
308*12c85518Srobert 
getBaseStmt() const309*12c85518Srobert   const ArraySubscriptExpr *getBaseStmt() const override { return ASE; }
310*12c85518Srobert 
getClaimedVarUseSites() const311*12c85518Srobert   DeclUseList getClaimedVarUseSites() const override {
312*12c85518Srobert     if (const auto *DRE =
313*12c85518Srobert             dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts())) {
314*12c85518Srobert       return {DRE};
315*12c85518Srobert     }
316*12c85518Srobert 
317*12c85518Srobert     return {};
318*12c85518Srobert   }
319*12c85518Srobert };
320*12c85518Srobert 
321*12c85518Srobert /// A pointer arithmetic expression of one of the forms:
322*12c85518Srobert ///  \code
323*12c85518Srobert ///  ptr + n | n + ptr | ptr - n | ptr += n | ptr -= n
324*12c85518Srobert ///  \endcode
325*12c85518Srobert class PointerArithmeticGadget : public WarningGadget {
326*12c85518Srobert   static constexpr const char *const PointerArithmeticTag = "ptrAdd";
327*12c85518Srobert   static constexpr const char *const PointerArithmeticPointerTag = "ptrAddPtr";
328*12c85518Srobert   const BinaryOperator *PA; // pointer arithmetic expression
329*12c85518Srobert   const Expr * Ptr;         // the pointer expression in `PA`
330*12c85518Srobert 
331*12c85518Srobert public:
PointerArithmeticGadget(const MatchFinder::MatchResult & Result)332*12c85518Srobert     PointerArithmeticGadget(const MatchFinder::MatchResult &Result)
333*12c85518Srobert       : WarningGadget(Kind::PointerArithmetic),
334*12c85518Srobert         PA(Result.Nodes.getNodeAs<BinaryOperator>(PointerArithmeticTag)),
335*12c85518Srobert         Ptr(Result.Nodes.getNodeAs<Expr>(PointerArithmeticPointerTag)) {}
336*12c85518Srobert 
classof(const Gadget * G)337*12c85518Srobert   static bool classof(const Gadget *G) {
338*12c85518Srobert     return G->getKind() == Kind::PointerArithmetic;
339*12c85518Srobert   }
340*12c85518Srobert 
matcher()341*12c85518Srobert   static Matcher matcher() {
342*12c85518Srobert     auto HasIntegerType = anyOf(
343*12c85518Srobert           hasType(isInteger()), hasType(enumType()));
344*12c85518Srobert     auto PtrAtRight = allOf(hasOperatorName("+"),
345*12c85518Srobert                             hasRHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)),
346*12c85518Srobert                             hasLHS(HasIntegerType));
347*12c85518Srobert     auto PtrAtLeft = allOf(
348*12c85518Srobert            anyOf(hasOperatorName("+"), hasOperatorName("-"),
349*12c85518Srobert                  hasOperatorName("+="), hasOperatorName("-=")),
350*12c85518Srobert            hasLHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)),
351*12c85518Srobert            hasRHS(HasIntegerType));
352*12c85518Srobert 
353*12c85518Srobert     return stmt(binaryOperator(anyOf(PtrAtLeft, PtrAtRight)).bind(PointerArithmeticTag));
354*12c85518Srobert   }
355*12c85518Srobert 
getBaseStmt() const356*12c85518Srobert   const Stmt *getBaseStmt() const override { return PA; }
357*12c85518Srobert 
getClaimedVarUseSites() const358*12c85518Srobert   DeclUseList getClaimedVarUseSites() const override {
359*12c85518Srobert     if (const auto *DRE =
360*12c85518Srobert             dyn_cast<DeclRefExpr>(Ptr->IgnoreParenImpCasts())) {
361*12c85518Srobert       return {DRE};
362*12c85518Srobert     }
363*12c85518Srobert 
364*12c85518Srobert     return {};
365*12c85518Srobert   }
366*12c85518Srobert   // FIXME: pointer adding zero should be fine
367*12c85518Srobert   //FIXME: this gadge will need a fix-it
368*12c85518Srobert };
369*12c85518Srobert } // namespace
370*12c85518Srobert 
371*12c85518Srobert namespace {
372*12c85518Srobert // An auxiliary tracking facility for the fixit analysis. It helps connect
373*12c85518Srobert // declarations to its and make sure we've covered all uses with our analysis
374*12c85518Srobert // before we try to fix the declaration.
375*12c85518Srobert class DeclUseTracker {
376*12c85518Srobert   using UseSetTy = SmallSet<const DeclRefExpr *, 16>;
377*12c85518Srobert   using DefMapTy = DenseMap<const VarDecl *, const DeclStmt *>;
378*12c85518Srobert 
379*12c85518Srobert   // Allocate on the heap for easier move.
380*12c85518Srobert   std::unique_ptr<UseSetTy> Uses{std::make_unique<UseSetTy>()};
381*12c85518Srobert   DefMapTy Defs{};
382*12c85518Srobert 
383*12c85518Srobert public:
384*12c85518Srobert   DeclUseTracker() = default;
385*12c85518Srobert   DeclUseTracker(const DeclUseTracker &) = delete; // Let's avoid copies.
386*12c85518Srobert   DeclUseTracker(DeclUseTracker &&) = default;
387*12c85518Srobert   DeclUseTracker &operator=(DeclUseTracker &&) = default;
388*12c85518Srobert 
389*12c85518Srobert   // Start tracking a freshly discovered DRE.
discoverUse(const DeclRefExpr * DRE)390*12c85518Srobert   void discoverUse(const DeclRefExpr *DRE) { Uses->insert(DRE); }
391*12c85518Srobert 
392*12c85518Srobert   // Stop tracking the DRE as it's been fully figured out.
claimUse(const DeclRefExpr * DRE)393*12c85518Srobert   void claimUse(const DeclRefExpr *DRE) {
394*12c85518Srobert     assert(Uses->count(DRE) &&
395*12c85518Srobert            "DRE not found or claimed by multiple matchers!");
396*12c85518Srobert     Uses->erase(DRE);
397*12c85518Srobert   }
398*12c85518Srobert 
399*12c85518Srobert   // A variable is unclaimed if at least one use is unclaimed.
hasUnclaimedUses(const VarDecl * VD) const400*12c85518Srobert   bool hasUnclaimedUses(const VarDecl *VD) const {
401*12c85518Srobert     // FIXME: Can this be less linear? Maybe maintain a map from VDs to DREs?
402*12c85518Srobert     return any_of(*Uses, [VD](const DeclRefExpr *DRE) {
403*12c85518Srobert       return DRE->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl();
404*12c85518Srobert     });
405*12c85518Srobert   }
406*12c85518Srobert 
discoverDecl(const DeclStmt * DS)407*12c85518Srobert   void discoverDecl(const DeclStmt *DS) {
408*12c85518Srobert     for (const Decl *D : DS->decls()) {
409*12c85518Srobert       if (const auto *VD = dyn_cast<VarDecl>(D)) {
410*12c85518Srobert         // FIXME: Assertion temporarily disabled due to a bug in
411*12c85518Srobert         // ASTMatcher internal behavior in presence of GNU
412*12c85518Srobert         // statement-expressions. We need to properly investigate this
413*12c85518Srobert         // because it can screw up our algorithm in other ways.
414*12c85518Srobert         // assert(Defs.count(VD) == 0 && "Definition already discovered!");
415*12c85518Srobert         Defs[VD] = DS;
416*12c85518Srobert       }
417*12c85518Srobert     }
418*12c85518Srobert   }
419*12c85518Srobert 
lookupDecl(const VarDecl * VD) const420*12c85518Srobert   const DeclStmt *lookupDecl(const VarDecl *VD) const {
421*12c85518Srobert     auto It = Defs.find(VD);
422*12c85518Srobert     assert(It != Defs.end() && "Definition never discovered!");
423*12c85518Srobert     return It->second;
424*12c85518Srobert   }
425*12c85518Srobert };
426*12c85518Srobert } // namespace
427*12c85518Srobert 
428*12c85518Srobert namespace {
429*12c85518Srobert // Strategy is a map from variables to the way we plan to emit fixes for
430*12c85518Srobert // these variables. It is figured out gradually by trying different fixes
431*12c85518Srobert // for different variables depending on gadgets in which these variables
432*12c85518Srobert // participate.
433*12c85518Srobert class Strategy {
434*12c85518Srobert public:
435*12c85518Srobert   enum class Kind {
436*12c85518Srobert     Wontfix,    // We don't plan to emit a fixit for this variable.
437*12c85518Srobert     Span,       // We recommend replacing the variable with std::span.
438*12c85518Srobert     Iterator,   // We recommend replacing the variable with std::span::iterator.
439*12c85518Srobert     Array,      // We recommend replacing the variable with std::array.
440*12c85518Srobert     Vector      // We recommend replacing the variable with std::vector.
441*12c85518Srobert   };
442*12c85518Srobert 
443*12c85518Srobert private:
444*12c85518Srobert   using MapTy = llvm::DenseMap<const VarDecl *, Kind>;
445*12c85518Srobert 
446*12c85518Srobert   MapTy Map;
447*12c85518Srobert 
448*12c85518Srobert public:
449*12c85518Srobert   Strategy() = default;
450*12c85518Srobert   Strategy(const Strategy &) = delete; // Let's avoid copies.
451*12c85518Srobert   Strategy(Strategy &&) = default;
452*12c85518Srobert 
set(const VarDecl * VD,Kind K)453*12c85518Srobert   void set(const VarDecl *VD, Kind K) {
454*12c85518Srobert     Map[VD] = K;
455*12c85518Srobert   }
456*12c85518Srobert 
lookup(const VarDecl * VD) const457*12c85518Srobert   Kind lookup(const VarDecl *VD) const {
458*12c85518Srobert     auto I = Map.find(VD);
459*12c85518Srobert     if (I == Map.end())
460*12c85518Srobert       return Kind::Wontfix;
461*12c85518Srobert 
462*12c85518Srobert     return I->second;
463*12c85518Srobert   }
464*12c85518Srobert };
465*12c85518Srobert } // namespace
466*12c85518Srobert 
467*12c85518Srobert /// Scan the function and return a list of gadgets found with provided kits.
findGadgets(const Decl * D)468*12c85518Srobert static std::tuple<FixableGadgetList, WarningGadgetList, DeclUseTracker> findGadgets(const Decl *D) {
469*12c85518Srobert 
470*12c85518Srobert   struct GadgetFinderCallback : MatchFinder::MatchCallback {
471*12c85518Srobert     FixableGadgetList FixableGadgets;
472*12c85518Srobert     WarningGadgetList WarningGadgets;
473*12c85518Srobert     DeclUseTracker Tracker;
474*12c85518Srobert 
475*12c85518Srobert     void run(const MatchFinder::MatchResult &Result) override {
476*12c85518Srobert       // In debug mode, assert that we've found exactly one gadget.
477*12c85518Srobert       // This helps us avoid conflicts in .bind() tags.
478*12c85518Srobert #if NDEBUG
479*12c85518Srobert #define NEXT return
480*12c85518Srobert #else
481*12c85518Srobert       [[maybe_unused]] int numFound = 0;
482*12c85518Srobert #define NEXT ++numFound
483*12c85518Srobert #endif
484*12c85518Srobert 
485*12c85518Srobert       if (const auto *DRE = Result.Nodes.getNodeAs<DeclRefExpr>("any_dre")) {
486*12c85518Srobert         Tracker.discoverUse(DRE);
487*12c85518Srobert         NEXT;
488*12c85518Srobert       }
489*12c85518Srobert 
490*12c85518Srobert       if (const auto *DS = Result.Nodes.getNodeAs<DeclStmt>("any_ds")) {
491*12c85518Srobert         Tracker.discoverDecl(DS);
492*12c85518Srobert         NEXT;
493*12c85518Srobert       }
494*12c85518Srobert 
495*12c85518Srobert       // Figure out which matcher we've found, and call the appropriate
496*12c85518Srobert       // subclass constructor.
497*12c85518Srobert       // FIXME: Can we do this more logarithmically?
498*12c85518Srobert #define FIXABLE_GADGET(name)                                                           \
499*12c85518Srobert       if (Result.Nodes.getNodeAs<Stmt>(#name)) {                               \
500*12c85518Srobert         FixableGadgets.push_back(std::make_unique<name ## Gadget>(Result));           \
501*12c85518Srobert         NEXT;                                                                  \
502*12c85518Srobert       }
503*12c85518Srobert #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
504*12c85518Srobert #define WARNING_GADGET(name)                                                           \
505*12c85518Srobert       if (Result.Nodes.getNodeAs<Stmt>(#name)) {                               \
506*12c85518Srobert         WarningGadgets.push_back(std::make_unique<name ## Gadget>(Result));           \
507*12c85518Srobert         NEXT;                                                                  \
508*12c85518Srobert       }
509*12c85518Srobert #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
510*12c85518Srobert 
511*12c85518Srobert       assert(numFound >= 1 && "Gadgets not found in match result!");
512*12c85518Srobert       assert(numFound <= 1 && "Conflicting bind tags in gadgets!");
513*12c85518Srobert     }
514*12c85518Srobert   };
515*12c85518Srobert 
516*12c85518Srobert   MatchFinder M;
517*12c85518Srobert   GadgetFinderCallback CB;
518*12c85518Srobert 
519*12c85518Srobert   // clang-format off
520*12c85518Srobert   M.addMatcher(
521*12c85518Srobert     stmt(forEveryDescendant(
522*12c85518Srobert       stmt(anyOf(
523*12c85518Srobert         // Add Gadget::matcher() for every gadget in the registry.
524*12c85518Srobert #define GADGET(x)                                                              \
525*12c85518Srobert         x ## Gadget::matcher().bind(#x),
526*12c85518Srobert #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
527*12c85518Srobert         // In parallel, match all DeclRefExprs so that to find out
528*12c85518Srobert         // whether there are any uncovered by gadgets.
529*12c85518Srobert         declRefExpr(anyOf(hasPointerType(), hasArrayType()),
530*12c85518Srobert                     to(varDecl())).bind("any_dre"),
531*12c85518Srobert         // Also match DeclStmts because we'll need them when fixing
532*12c85518Srobert         // their underlying VarDecls that otherwise don't have
533*12c85518Srobert         // any backreferences to DeclStmts.
534*12c85518Srobert         declStmt().bind("any_ds")
535*12c85518Srobert       ))
536*12c85518Srobert       // FIXME: Idiomatically there should be a forCallable(equalsNode(D))
537*12c85518Srobert       // here, to make sure that the statement actually belongs to the
538*12c85518Srobert       // function and not to a nested function. However, forCallable uses
539*12c85518Srobert       // ParentMap which can't be used before the AST is fully constructed.
540*12c85518Srobert       // The original problem doesn't sound like it needs ParentMap though,
541*12c85518Srobert       // maybe there's a more direct solution?
542*12c85518Srobert     )),
543*12c85518Srobert     &CB
544*12c85518Srobert   );
545*12c85518Srobert   // clang-format on
546*12c85518Srobert 
547*12c85518Srobert   M.match(*D->getBody(), D->getASTContext());
548*12c85518Srobert 
549*12c85518Srobert   // Gadgets "claim" variables they're responsible for. Once this loop finishes,
550*12c85518Srobert   // the tracker will only track DREs that weren't claimed by any gadgets,
551*12c85518Srobert   // i.e. not understood by the analysis.
552*12c85518Srobert   for (const auto &G : CB.FixableGadgets) {
553*12c85518Srobert     for (const auto *DRE : G->getClaimedVarUseSites()) {
554*12c85518Srobert       CB.Tracker.claimUse(DRE);
555*12c85518Srobert     }
556*12c85518Srobert   }
557*12c85518Srobert 
558*12c85518Srobert   return {std::move(CB.FixableGadgets), std::move(CB.WarningGadgets), std::move(CB.Tracker)};
559*12c85518Srobert }
560*12c85518Srobert 
561*12c85518Srobert struct WarningGadgetSets {
562*12c85518Srobert   std::map<const VarDecl *, std::set<std::unique_ptr<WarningGadget>>> byVar;
563*12c85518Srobert   // These Gadgets are not related to pointer variables (e. g. temporaries).
564*12c85518Srobert   llvm::SmallVector<std::unique_ptr<WarningGadget>, 16> noVar;
565*12c85518Srobert };
566*12c85518Srobert 
567*12c85518Srobert static WarningGadgetSets
groupWarningGadgetsByVar(WarningGadgetList && AllUnsafeOperations)568*12c85518Srobert groupWarningGadgetsByVar(WarningGadgetList &&AllUnsafeOperations) {
569*12c85518Srobert   WarningGadgetSets result;
570*12c85518Srobert   // If some gadgets cover more than one
571*12c85518Srobert   // variable, they'll appear more than once in the map.
572*12c85518Srobert   for (auto &G : AllUnsafeOperations) {
573*12c85518Srobert     DeclUseList ClaimedVarUseSites = G->getClaimedVarUseSites();
574*12c85518Srobert 
575*12c85518Srobert     bool AssociatedWithVarDecl = false;
576*12c85518Srobert     for (const DeclRefExpr *DRE : ClaimedVarUseSites) {
577*12c85518Srobert       if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
578*12c85518Srobert         result.byVar[VD].emplace(std::move(G));
579*12c85518Srobert         AssociatedWithVarDecl = true;
580*12c85518Srobert       }
581*12c85518Srobert     }
582*12c85518Srobert 
583*12c85518Srobert     if (!AssociatedWithVarDecl) {
584*12c85518Srobert       result.noVar.emplace_back(std::move(G));
585*12c85518Srobert       continue;
586*12c85518Srobert     }
587*12c85518Srobert   }
588*12c85518Srobert   return result;
589*12c85518Srobert }
590*12c85518Srobert 
591*12c85518Srobert struct FixableGadgetSets {
592*12c85518Srobert   std::map<const VarDecl *, std::set<std::unique_ptr<FixableGadget>>> byVar;
593*12c85518Srobert };
594*12c85518Srobert 
595*12c85518Srobert static FixableGadgetSets
groupFixablesByVar(FixableGadgetList && AllFixableOperations)596*12c85518Srobert groupFixablesByVar(FixableGadgetList &&AllFixableOperations) {
597*12c85518Srobert   FixableGadgetSets FixablesForUnsafeVars;
598*12c85518Srobert   for (auto &F : AllFixableOperations) {
599*12c85518Srobert     DeclUseList DREs = F->getClaimedVarUseSites();
600*12c85518Srobert 
601*12c85518Srobert     for (const DeclRefExpr *DRE : DREs) {
602*12c85518Srobert       if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
603*12c85518Srobert         FixablesForUnsafeVars.byVar[VD].emplace(std::move(F));
604*12c85518Srobert       }
605*12c85518Srobert     }
606*12c85518Srobert   }
607*12c85518Srobert   return FixablesForUnsafeVars;
608*12c85518Srobert }
609*12c85518Srobert 
610*12c85518Srobert static std::map<const VarDecl *, FixItList>
getFixIts(FixableGadgetSets & FixablesForUnsafeVars,const Strategy & S)611*12c85518Srobert getFixIts(FixableGadgetSets &FixablesForUnsafeVars, const Strategy &S) {
612*12c85518Srobert   std::map<const VarDecl *, FixItList> FixItsForVariable;
613*12c85518Srobert   for (const auto &[VD, Fixables] : FixablesForUnsafeVars.byVar) {
614*12c85518Srobert     // TODO fixVariable - fixit for the variable itself
615*12c85518Srobert     bool ImpossibleToFix = false;
616*12c85518Srobert     llvm::SmallVector<FixItHint, 16> FixItsForVD;
617*12c85518Srobert     for (const auto &F : Fixables) {
618*12c85518Srobert       llvm::Optional<FixItList> Fixits = F->getFixits(S);
619*12c85518Srobert       if (!Fixits) {
620*12c85518Srobert         ImpossibleToFix = true;
621*12c85518Srobert         break;
622*12c85518Srobert       } else {
623*12c85518Srobert         const FixItList CorrectFixes = Fixits.value();
624*12c85518Srobert         FixItsForVD.insert(FixItsForVD.end(), CorrectFixes.begin(),
625*12c85518Srobert                            CorrectFixes.end());
626*12c85518Srobert       }
627*12c85518Srobert     }
628*12c85518Srobert     if (ImpossibleToFix)
629*12c85518Srobert       FixItsForVariable.erase(VD);
630*12c85518Srobert     else
631*12c85518Srobert       FixItsForVariable[VD].insert(FixItsForVariable[VD].end(),
632*12c85518Srobert                                    FixItsForVD.begin(), FixItsForVD.end());
633*12c85518Srobert   }
634*12c85518Srobert   return FixItsForVariable;
635*12c85518Srobert }
636*12c85518Srobert 
637*12c85518Srobert static Strategy
getNaiveStrategy(const llvm::SmallVectorImpl<const VarDecl * > & UnsafeVars)638*12c85518Srobert getNaiveStrategy(const llvm::SmallVectorImpl<const VarDecl *> &UnsafeVars) {
639*12c85518Srobert   Strategy S;
640*12c85518Srobert   for (const VarDecl *VD : UnsafeVars) {
641*12c85518Srobert     S.set(VD, Strategy::Kind::Span);
642*12c85518Srobert   }
643*12c85518Srobert   return S;
644*12c85518Srobert }
645*12c85518Srobert 
checkUnsafeBufferUsage(const Decl * D,UnsafeBufferUsageHandler & Handler)646*12c85518Srobert void clang::checkUnsafeBufferUsage(const Decl *D,
647*12c85518Srobert                                    UnsafeBufferUsageHandler &Handler) {
648*12c85518Srobert   assert(D && D->getBody());
649*12c85518Srobert 
650*12c85518Srobert   WarningGadgetSets UnsafeOps;
651*12c85518Srobert   FixableGadgetSets FixablesForUnsafeVars;
652*12c85518Srobert   DeclUseTracker Tracker;
653*12c85518Srobert 
654*12c85518Srobert   {
655*12c85518Srobert     auto [FixableGadgets, WarningGadgets, TrackerRes] = findGadgets(D);
656*12c85518Srobert     UnsafeOps = groupWarningGadgetsByVar(std::move(WarningGadgets));
657*12c85518Srobert     FixablesForUnsafeVars = groupFixablesByVar(std::move(FixableGadgets));
658*12c85518Srobert     Tracker = std::move(TrackerRes);
659*12c85518Srobert   }
660*12c85518Srobert 
661*12c85518Srobert   // Filter out non-local vars and vars with unclaimed DeclRefExpr-s.
662*12c85518Srobert   for (auto it = FixablesForUnsafeVars.byVar.cbegin();
663*12c85518Srobert        it != FixablesForUnsafeVars.byVar.cend();) {
664*12c85518Srobert     // FIXME: Support ParmVarDecl as well.
665*12c85518Srobert     if (!it->first->isLocalVarDecl() || Tracker.hasUnclaimedUses(it->first)) {
666*12c85518Srobert       it = FixablesForUnsafeVars.byVar.erase(it);
667*12c85518Srobert     } else {
668*12c85518Srobert       ++it;
669*12c85518Srobert     }
670*12c85518Srobert   }
671*12c85518Srobert 
672*12c85518Srobert   llvm::SmallVector<const VarDecl *, 16> UnsafeVars;
673*12c85518Srobert   for (const auto &[VD, ignore] : FixablesForUnsafeVars.byVar)
674*12c85518Srobert     UnsafeVars.push_back(VD);
675*12c85518Srobert 
676*12c85518Srobert   Strategy NaiveStrategy = getNaiveStrategy(UnsafeVars);
677*12c85518Srobert   std::map<const VarDecl *, FixItList> FixItsForVariable =
678*12c85518Srobert       getFixIts(FixablesForUnsafeVars, NaiveStrategy);
679*12c85518Srobert 
680*12c85518Srobert   // FIXME Detect overlapping FixIts.
681*12c85518Srobert 
682*12c85518Srobert   for (const auto &G : UnsafeOps.noVar) {
683*12c85518Srobert     Handler.handleUnsafeOperation(G->getBaseStmt(), /*IsRelatedToDecl=*/false);
684*12c85518Srobert   }
685*12c85518Srobert 
686*12c85518Srobert   for (const auto &[VD, WarningGadgets] : UnsafeOps.byVar) {
687*12c85518Srobert     auto FixItsIt = FixItsForVariable.find(VD);
688*12c85518Srobert     Handler.handleFixableVariable(VD, FixItsIt != FixItsForVariable.end()
689*12c85518Srobert                                           ? std::move(FixItsIt->second)
690*12c85518Srobert                                           : FixItList{});
691*12c85518Srobert     for (const auto &G : WarningGadgets) {
692*12c85518Srobert       Handler.handleUnsafeOperation(G->getBaseStmt(), /*IsRelatedToDecl=*/true);
693*12c85518Srobert     }
694*12c85518Srobert   }
695*12c85518Srobert }
696