xref: /llvm-project/clang-tools-extra/clang-tidy/bugprone/NotNullTerminatedResultCheck.cpp (revision 672207c319a06f20dc634bcd21678d5dbbe7a6b9)
1 //===--- NotNullTerminatedResultCheck.cpp - clang-tidy ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "NotNullTerminatedResultCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Frontend/CompilerInstance.h"
13 #include "clang/Lex/Lexer.h"
14 #include "clang/Lex/PPCallbacks.h"
15 
16 using namespace clang::ast_matchers;
17 
18 namespace clang {
19 namespace tidy {
20 namespace bugprone {
21 
// Identifiers under which the matchers in this file bind AST nodes; the helper
// functions below look the nodes up again through 'Result.Nodes.getNodeAs<>()'.
constexpr llvm::StringLiteral FunctionExprName = "FunctionExpr";
constexpr llvm::StringLiteral CastExprName = "CastExpr";
constexpr llvm::StringLiteral UnknownDestName = "UnknownDest";
constexpr llvm::StringLiteral DestArrayTyName = "DestArrayTy";
constexpr llvm::StringLiteral DestVarDeclName = "DestVarDecl";
constexpr llvm::StringLiteral DestMallocExprName = "DestMalloc";
constexpr llvm::StringLiteral DestExprName = "DestExpr";
constexpr llvm::StringLiteral SrcVarDeclName = "SrcVarDecl";
constexpr llvm::StringLiteral SrcExprName = "SrcExpr";
constexpr llvm::StringLiteral LengthExprName = "LengthExpr";
constexpr llvm::StringLiteral WrongLengthExprName = "WrongLength";
constexpr llvm::StringLiteral UnknownLengthName = "UnknownLength";
34 
// Direction in which 'lengthExprHandle()' adjusts a length expression by one.
enum class LengthHandleKind { Increase, Decrease };
36 
namespace {
// Preprocessor stored by 'registerPPCallbacks()'; 'lengthExprHandle()' walks
// its macro table to find out whether a length expression is a macro name.
static Preprocessor *PP;
} // namespace
40 
41 // Returns the expression of destination's capacity which is part of a
42 // 'VariableArrayType', 'ConstantArrayTypeLoc' or an argument of a 'malloc()'
43 // family function call.
44 static const Expr *getDestCapacityExpr(const MatchFinder::MatchResult &Result) {
45   if (const auto *DestMalloc = Result.Nodes.getNodeAs<Expr>(DestMallocExprName))
46     return DestMalloc;
47 
48   if (const auto *DestVAT =
49           Result.Nodes.getNodeAs<VariableArrayType>(DestArrayTyName))
50     return DestVAT->getSizeExpr();
51 
52   if (const auto *DestVD = Result.Nodes.getNodeAs<VarDecl>(DestVarDeclName))
53     if (const TypeLoc DestTL = DestVD->getTypeSourceInfo()->getTypeLoc())
54       if (const auto DestCTL = DestTL.getAs<ConstantArrayTypeLoc>())
55         return DestCTL.getSizeExpr();
56 
57   return nullptr;
58 }
59 
60 // Returns the length of \p E as an 'IntegerLiteral' or a 'StringLiteral'
61 // without the null-terminator.
62 static unsigned getLength(const Expr *E,
63                           const MatchFinder::MatchResult &Result) {
64   if (!E)
65     return 0;
66 
67   Expr::EvalResult Length;
68   E = E->IgnoreImpCasts();
69 
70   if (const auto *LengthDRE = dyn_cast<DeclRefExpr>(E))
71     if (const auto *LengthVD = dyn_cast<VarDecl>(LengthDRE->getDecl()))
72       if (!isa<ParmVarDecl>(LengthVD))
73         if (const Expr *LengthInit = LengthVD->getInit())
74           if (LengthInit->EvaluateAsInt(Length, *Result.Context))
75             return Length.Val.getInt().getZExtValue();
76 
77   if (const auto *LengthIL = dyn_cast<IntegerLiteral>(E))
78     return LengthIL->getValue().getZExtValue();
79 
80   if (const auto *StrDRE = dyn_cast<DeclRefExpr>(E))
81     if (const auto *StrVD = dyn_cast<VarDecl>(StrDRE->getDecl()))
82       if (const Expr *StrInit = StrVD->getInit())
83         if (const auto *StrSL =
84                 dyn_cast<StringLiteral>(StrInit->IgnoreImpCasts()))
85           return StrSL->getLength();
86 
87   if (const auto *SrcSL = dyn_cast<StringLiteral>(E))
88     return SrcSL->getLength();
89 
90   return 0;
91 }
92 
93 // Returns the capacity of the destination array.
94 // For example in 'char dest[13]; memcpy(dest, ...)' it returns 13.
95 static int getDestCapacity(const MatchFinder::MatchResult &Result) {
96   if (const auto *DestCapacityExpr = getDestCapacityExpr(Result))
97     return getLength(DestCapacityExpr, Result);
98 
99   return 0;
100 }
101 
102 // Returns the 'strlen()' if it is the given length.
103 static const CallExpr *getStrlenExpr(const MatchFinder::MatchResult &Result) {
104   if (const auto *StrlenExpr =
105           Result.Nodes.getNodeAs<CallExpr>(WrongLengthExprName))
106     if (const Decl *D = StrlenExpr->getCalleeDecl())
107       if (const FunctionDecl *FD = D->getAsFunction())
108         if (const IdentifierInfo *II = FD->getIdentifier())
109           if (II->isStr("strlen") || II->isStr("wcslen"))
110             return StrlenExpr;
111 
112   return nullptr;
113 }
114 
115 // Returns the length which is given in the memory/string handler function.
116 // For example in 'memcpy(dest, "foobar", 3)' it returns 3.
117 static int getGivenLength(const MatchFinder::MatchResult &Result) {
118   if (Result.Nodes.getNodeAs<Expr>(UnknownLengthName))
119     return 0;
120 
121   if (int Length =
122           getLength(Result.Nodes.getNodeAs<Expr>(WrongLengthExprName), Result))
123     return Length;
124 
125   if (int Length =
126           getLength(Result.Nodes.getNodeAs<Expr>(LengthExprName), Result))
127     return Length;
128 
129   // Special case, for example 'strlen("foo")'.
130   if (const CallExpr *StrlenCE = getStrlenExpr(Result))
131     if (const Expr *Arg = StrlenCE->getArg(0)->IgnoreImpCasts())
132       if (int ArgLength = getLength(Arg, Result))
133         return ArgLength;
134 
135   return 0;
136 }
137 
138 // Returns a string representation of \p E.
139 static StringRef exprToStr(const Expr *E,
140                            const MatchFinder::MatchResult &Result) {
141   if (!E)
142     return "";
143 
144   return Lexer::getSourceText(
145       CharSourceRange::getTokenRange(E->getSourceRange()),
146       *Result.SourceManager, Result.Context->getLangOpts(), 0);
147 }
148 
149 // Returns the proper token based end location of \p E.
150 static SourceLocation exprLocEnd(const Expr *E,
151                                  const MatchFinder::MatchResult &Result) {
152   return Lexer::getLocForEndOfToken(E->getEndLoc(), 0, *Result.SourceManager,
153                                     Result.Context->getLangOpts());
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Rewrite decision helper functions.
158 //===----------------------------------------------------------------------===//
159 
160 // Increment by integer '1' can result in overflow if it is the maximal value.
161 // After that it would be extended to 'size_t' and its value would be wrong,
162 // therefore we have to inject '+ 1UL' instead.
163 static bool isInjectUL(const MatchFinder::MatchResult &Result) {
164   return getGivenLength(Result) == std::numeric_limits<int>::max();
165 }
166 
167 // If the capacity of the destination array is unknown it is denoted as unknown.
168 static bool isKnownDest(const MatchFinder::MatchResult &Result) {
169   return !Result.Nodes.getNodeAs<Expr>(UnknownDestName);
170 }
171 
172 // True if the capacity of the destination array is based on the given length,
173 // therefore we assume that it cannot overflow (e.g. 'malloc(given_length + 1)'
174 static bool isDestBasedOnGivenLength(const MatchFinder::MatchResult &Result) {
175   StringRef DestCapacityExprStr =
176       exprToStr(getDestCapacityExpr(Result), Result).trim();
177   StringRef LengthExprStr =
178       exprToStr(Result.Nodes.getNodeAs<Expr>(LengthExprName), Result).trim();
179 
180   return DestCapacityExprStr != "" && LengthExprStr != "" &&
181          DestCapacityExprStr.contains(LengthExprStr);
182 }
183 
184 // Writing and reading from the same memory cannot remove the null-terminator.
185 static bool isDestAndSrcEquals(const MatchFinder::MatchResult &Result) {
186   if (const auto *DestDRE = Result.Nodes.getNodeAs<DeclRefExpr>(DestExprName))
187     if (const auto *SrcDRE = Result.Nodes.getNodeAs<DeclRefExpr>(SrcExprName))
188       return DestDRE->getDecl()->getCanonicalDecl() ==
189              SrcDRE->getDecl()->getCanonicalDecl();
190 
191   return false;
192 }
193 
194 // For example 'std::string str = "foo"; memcpy(dst, str.data(), str.length())'.
195 static bool isStringDataAndLength(const MatchFinder::MatchResult &Result) {
196   const auto *DestExpr =
197       Result.Nodes.getNodeAs<CXXMemberCallExpr>(DestExprName);
198   const auto *SrcExpr = Result.Nodes.getNodeAs<CXXMemberCallExpr>(SrcExprName);
199   const auto *LengthExpr =
200       Result.Nodes.getNodeAs<CXXMemberCallExpr>(WrongLengthExprName);
201 
202   StringRef DestStr = "", SrcStr = "", LengthStr = "";
203   if (DestExpr)
204     if (const CXXMethodDecl *DestMD = DestExpr->getMethodDecl())
205       DestStr = DestMD->getName();
206 
207   if (SrcExpr)
208     if (const CXXMethodDecl *SrcMD = SrcExpr->getMethodDecl())
209       SrcStr = SrcMD->getName();
210 
211   if (LengthExpr)
212     if (const CXXMethodDecl *LengthMD = LengthExpr->getMethodDecl())
213       LengthStr = LengthMD->getName();
214 
215   return (LengthStr == "length" || LengthStr == "size") &&
216          (SrcStr == "data" || DestStr == "data");
217 }
218 
219 static bool
220 isGivenLengthEqualToSrcLength(const MatchFinder::MatchResult &Result) {
221   if (Result.Nodes.getNodeAs<Expr>(UnknownLengthName))
222     return false;
223 
224   if (isStringDataAndLength(Result))
225     return true;
226 
227   int GivenLength = getGivenLength(Result);
228   int SrcLength = getLength(Result.Nodes.getNodeAs<Expr>(SrcExprName), Result);
229 
230   if (GivenLength != 0 && SrcLength != 0 && GivenLength == SrcLength)
231     return true;
232 
233   if (const auto *LengthExpr = Result.Nodes.getNodeAs<Expr>(LengthExprName))
234     if (dyn_cast<BinaryOperator>(LengthExpr->IgnoreParenImpCasts()))
235       return false;
236 
237   // Check the strlen()'s argument's 'VarDecl' is equal to the source 'VarDecl'.
238   if (const CallExpr *StrlenCE = getStrlenExpr(Result))
239     if (const auto *ArgDRE =
240             dyn_cast<DeclRefExpr>(StrlenCE->getArg(0)->IgnoreImpCasts()))
241       if (const auto *SrcVD = Result.Nodes.getNodeAs<VarDecl>(SrcVarDeclName))
242         return dyn_cast<VarDecl>(ArgDRE->getDecl()) == SrcVD;
243 
244   return false;
245 }
246 
247 static bool isCorrectGivenLength(const MatchFinder::MatchResult &Result) {
248   if (Result.Nodes.getNodeAs<Expr>(UnknownLengthName))
249     return false;
250 
251   return !isGivenLengthEqualToSrcLength(Result);
252 }
253 
254 // If we rewrite the function call we need to create extra space to hold the
255 // null terminator. The new necessary capacity overflows without that '+ 1'
256 // size and we need to correct the given capacity.
257 static bool isDestCapacityOverflows(const MatchFinder::MatchResult &Result) {
258   if (!isKnownDest(Result))
259     return true;
260 
261   const Expr *DestCapacityExpr = getDestCapacityExpr(Result);
262   int DestCapacity = getLength(DestCapacityExpr, Result);
263   int GivenLength = getGivenLength(Result);
264 
265   if (GivenLength != 0 && DestCapacity != 0)
266     return isGivenLengthEqualToSrcLength(Result) && DestCapacity == GivenLength;
267 
268   // Assume that the destination array's capacity cannot overflow if the
269   // expression of the memory allocation contains '+ 1'.
270   StringRef DestCapacityExprStr = exprToStr(DestCapacityExpr, Result);
271   if (DestCapacityExprStr.contains("+1") || DestCapacityExprStr.contains("+ 1"))
272     return false;
273 
274   return true;
275 }
276 
277 static bool
278 isFixedGivenLengthAndUnknownSrc(const MatchFinder::MatchResult &Result) {
279   if (Result.Nodes.getNodeAs<IntegerLiteral>(WrongLengthExprName))
280     return !getLength(Result.Nodes.getNodeAs<Expr>(SrcExprName), Result);
281 
282   return false;
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // Code injection functions.
287 //===----------------------------------------------------------------------===//
288 
289 // Increase or decrease \p LengthExpr by one.
290 static void lengthExprHandle(const Expr *LengthExpr,
291                              LengthHandleKind LengthHandle,
292                              const MatchFinder::MatchResult &Result,
293                              DiagnosticBuilder &Diag) {
294   LengthExpr = LengthExpr->IgnoreParenImpCasts();
295 
296   // See whether we work with a macro.
297   bool IsMacroDefinition = false;
298   StringRef LengthExprStr = exprToStr(LengthExpr, Result);
299   Preprocessor::macro_iterator It = PP->macro_begin();
300   while (It != PP->macro_end() && !IsMacroDefinition) {
301     if (It->first->getName() == LengthExprStr)
302       IsMacroDefinition = true;
303 
304     ++It;
305   }
306 
307   // Try to obtain an 'IntegerLiteral' and adjust it.
308   if (!IsMacroDefinition) {
309     if (const auto *LengthIL = dyn_cast<IntegerLiteral>(LengthExpr)) {
310       size_t NewLength = LengthIL->getValue().getZExtValue() +
311                          (LengthHandle == LengthHandleKind::Increase
312                               ? (isInjectUL(Result) ? 1UL : 1)
313                               : -1);
314 
315       const auto NewLengthFix = FixItHint::CreateReplacement(
316           LengthIL->getSourceRange(),
317           (Twine(NewLength) + (isInjectUL(Result) ? "UL" : "")).str());
318       Diag << NewLengthFix;
319       return;
320     }
321   }
322 
323   // Try to obtain and remove the '+ 1' string as a decrement fix.
324   const auto *BO = dyn_cast<BinaryOperator>(LengthExpr);
325   if (BO && BO->getOpcode() == BO_Add &&
326       LengthHandle == LengthHandleKind::Decrease) {
327     const Expr *LhsExpr = BO->getLHS()->IgnoreImpCasts();
328     const Expr *RhsExpr = BO->getRHS()->IgnoreImpCasts();
329 
330     if (const auto *LhsIL = dyn_cast<IntegerLiteral>(LhsExpr)) {
331       if (LhsIL->getValue().getZExtValue() == 1) {
332         Diag << FixItHint::CreateRemoval(
333             {LhsIL->getBeginLoc(),
334              RhsExpr->getBeginLoc().getLocWithOffset(-1)});
335         return;
336       }
337     }
338 
339     if (const auto *RhsIL = dyn_cast<IntegerLiteral>(RhsExpr)) {
340       if (RhsIL->getValue().getZExtValue() == 1) {
341         Diag << FixItHint::CreateRemoval(
342             {LhsExpr->getEndLoc().getLocWithOffset(1), RhsIL->getEndLoc()});
343         return;
344       }
345     }
346   }
347 
348   // Try to inject the '+ 1'/'- 1' string.
349   bool NeedInnerParen = BO && BO->getOpcode() != BO_Add;
350 
351   if (NeedInnerParen)
352     Diag << FixItHint::CreateInsertion(LengthExpr->getBeginLoc(), "(");
353 
354   SmallString<8> Injection;
355   if (NeedInnerParen)
356     Injection += ')';
357   Injection += LengthHandle == LengthHandleKind::Increase ? " + 1" : " - 1";
358   if (isInjectUL(Result))
359     Injection += "UL";
360 
361   Diag << FixItHint::CreateInsertion(exprLocEnd(LengthExpr, Result), Injection);
362 }
363 
364 static void lengthArgHandle(LengthHandleKind LengthHandle,
365                             const MatchFinder::MatchResult &Result,
366                             DiagnosticBuilder &Diag) {
367   const auto *LengthExpr = Result.Nodes.getNodeAs<Expr>(LengthExprName);
368   lengthExprHandle(LengthExpr, LengthHandle, Result, Diag);
369 }
370 
371 static void lengthArgPosHandle(unsigned ArgPos, LengthHandleKind LengthHandle,
372                                const MatchFinder::MatchResult &Result,
373                                DiagnosticBuilder &Diag) {
374   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
375   lengthExprHandle(FunctionExpr->getArg(ArgPos), LengthHandle, Result, Diag);
376 }
377 
378 // The string handler functions are only operates with plain 'char'/'wchar_t'
379 // without 'unsigned/signed', therefore we need to cast it.
380 static bool isDestExprFix(const MatchFinder::MatchResult &Result,
381                           DiagnosticBuilder &Diag) {
382   const auto *Dest = Result.Nodes.getNodeAs<Expr>(DestExprName);
383   if (!Dest)
384     return false;
385 
386   std::string TempTyStr = Dest->getType().getAsString();
387   StringRef TyStr = TempTyStr;
388   if (TyStr.startswith("char") || TyStr.startswith("wchar_t"))
389     return false;
390 
391   Diag << FixItHint::CreateInsertion(Dest->getBeginLoc(), "(char *)");
392   return true;
393 }
394 
395 // If the destination array is the same length as the given length we have to
396 // increase the capacity by one to create space for the null terminator.
397 static bool isDestCapacityFix(const MatchFinder::MatchResult &Result,
398                               DiagnosticBuilder &Diag) {
399   bool IsOverflows = isDestCapacityOverflows(Result);
400   if (IsOverflows)
401     if (const Expr *CapacityExpr = getDestCapacityExpr(Result))
402       lengthExprHandle(CapacityExpr, LengthHandleKind::Increase, Result, Diag);
403 
404   return IsOverflows;
405 }
406 
407 static void removeArg(int ArgPos, const MatchFinder::MatchResult &Result,
408                       DiagnosticBuilder &Diag) {
409   // This is the following structure: (src, '\0', strlen(src))
410   //                     ArgToRemove:             ~~~~~~~~~~~
411   //                          LHSArg:       ~~~~
412   //                    RemoveArgFix:           ~~~~~~~~~~~~~
413   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
414   const Expr *ArgToRemove = FunctionExpr->getArg(ArgPos);
415   const Expr *LHSArg = FunctionExpr->getArg(ArgPos - 1);
416   const auto RemoveArgFix = FixItHint::CreateRemoval(
417       SourceRange(exprLocEnd(LHSArg, Result),
418                   exprLocEnd(ArgToRemove, Result).getLocWithOffset(-1)));
419   Diag << RemoveArgFix;
420 }
421 
422 static void renameFunc(StringRef NewFuncName,
423                        const MatchFinder::MatchResult &Result,
424                        DiagnosticBuilder &Diag) {
425   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
426   int FuncNameLength =
427       FunctionExpr->getDirectCallee()->getIdentifier()->getLength();
428   SourceRange FuncNameRange(
429       FunctionExpr->getBeginLoc(),
430       FunctionExpr->getBeginLoc().getLocWithOffset(FuncNameLength - 1));
431 
432   const auto FuncNameFix =
433       FixItHint::CreateReplacement(FuncNameRange, NewFuncName);
434   Diag << FuncNameFix;
435 }
436 
437 static void renameMemcpy(StringRef Name, bool IsCopy, bool IsSafe,
438                          const MatchFinder::MatchResult &Result,
439                          DiagnosticBuilder &Diag) {
440   SmallString<10> NewFuncName;
441   NewFuncName = (Name[0] != 'w') ? "str" : "wcs";
442   NewFuncName += IsCopy ? "cpy" : "ncpy";
443   NewFuncName += IsSafe ? "_s" : "";
444   renameFunc(NewFuncName, Result, Diag);
445 }
446 
447 static void insertDestCapacityArg(bool IsOverflows, StringRef Name,
448                                   const MatchFinder::MatchResult &Result,
449                                   DiagnosticBuilder &Diag) {
450   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
451   SmallString<64> NewSecondArg;
452 
453   if (int DestLength = getDestCapacity(Result)) {
454     NewSecondArg = Twine(IsOverflows ? DestLength + 1 : DestLength).str();
455   } else {
456     NewSecondArg =
457         (Twine(exprToStr(getDestCapacityExpr(Result), Result)) +
458          (IsOverflows ? (!isInjectUL(Result) ? " + 1" : " + 1UL") : ""))
459             .str();
460   }
461 
462   NewSecondArg += ", ";
463   const auto InsertNewArgFix = FixItHint::CreateInsertion(
464       FunctionExpr->getArg(1)->getBeginLoc(), NewSecondArg);
465   Diag << InsertNewArgFix;
466 }
467 
468 static void insertNullTerminatorExpr(StringRef Name,
469                                      const MatchFinder::MatchResult &Result,
470                                      DiagnosticBuilder &Diag) {
471   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
472   int FuncLocStartColumn = Result.SourceManager->getPresumedColumnNumber(
473       FunctionExpr->getBeginLoc());
474   SourceRange SpaceRange(
475       FunctionExpr->getBeginLoc().getLocWithOffset(-FuncLocStartColumn + 1),
476       FunctionExpr->getBeginLoc());
477   StringRef SpaceBeforeStmtStr = Lexer::getSourceText(
478       CharSourceRange::getCharRange(SpaceRange), *Result.SourceManager,
479       Result.Context->getLangOpts(), 0);
480 
481   SmallString<128> NewAddNullTermExprStr;
482   NewAddNullTermExprStr =
483       (Twine('\n') + SpaceBeforeStmtStr +
484        exprToStr(Result.Nodes.getNodeAs<Expr>(DestExprName), Result) + "[" +
485        exprToStr(Result.Nodes.getNodeAs<Expr>(LengthExprName), Result) +
486        "] = " + ((Name[0] != 'w') ? "\'\\0\';" : "L\'\\0\';"))
487           .str();
488 
489   const auto AddNullTerminatorExprFix = FixItHint::CreateInsertion(
490       exprLocEnd(FunctionExpr, Result).getLocWithOffset(1),
491       NewAddNullTermExprStr);
492   Diag << AddNullTerminatorExprFix;
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // Checker logic with the matchers.
497 //===----------------------------------------------------------------------===//
498 
// Constructs the check and reads the 'WantToUseSafeFunctions' option, which
// defaults to true when it is not configured.
NotNullTerminatedResultCheck::NotNullTerminatedResultCheck(
    StringRef Name, ClangTidyContext *Context)
    : ClangTidyCheck(Name, Context),
      WantToUseSafeFunctions(Options.get("WantToUseSafeFunctions", true)) {}
503 
// Stores the current value of the 'WantToUseSafeFunctions' option into
// \p Opts.
void NotNullTerminatedResultCheck::storeOptions(
    ClangTidyOptions::OptionMap &Opts) {
  Options.store(Opts, "WantToUseSafeFunctions", WantToUseSafeFunctions);
}
508 
// Does not register any callback; it only remembers the preprocessor in the
// file-local 'PP' pointer so that the fix-it helpers can inspect the macro
// table later. \p SM and \p ModuleExpanderPP are unused.
void NotNullTerminatedResultCheck::registerPPCallbacks(
    const SourceManager &SM, Preprocessor *pp, Preprocessor *ModuleExpanderPP) {
  PP = pp;
}
513 
514 namespace {
515 AST_MATCHER_P(Expr, hasDefinition, ast_matchers::internal::Matcher<Expr>,
516               InnerMatcher) {
517   const Expr *SimpleNode = &Node;
518   SimpleNode = SimpleNode->IgnoreParenImpCasts();
519 
520   if (InnerMatcher.matches(*SimpleNode, Finder, Builder))
521     return true;
522 
523   auto DREHasInit = ignoringImpCasts(
524       declRefExpr(to(varDecl(hasInitializer(ignoringImpCasts(InnerMatcher))))));
525 
526   if (DREHasInit.matches(*SimpleNode, Finder, Builder))
527     return true;
528 
529   const char *const VarDeclName = "variable-declaration";
530   auto DREHasDefinition = ignoringImpCasts(declRefExpr(
531       allOf(to(varDecl().bind(VarDeclName)),
532             hasAncestor(compoundStmt(hasDescendant(binaryOperator(
533                 hasLHS(declRefExpr(to(varDecl(equalsBoundNode(VarDeclName))))),
534                 hasRHS(ignoringImpCasts(InnerMatcher)))))))));
535 
536   if (DREHasDefinition.matches(*SimpleNode, Finder, Builder))
537     return true;
538 
539   return false;
540 }
541 } // namespace
542 
543 void NotNullTerminatedResultCheck::registerMatchers(MatchFinder *Finder) {
544   auto IncOp =
545       binaryOperator(hasOperatorName("+"),
546                      hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
547 
548   auto DecOp =
549       binaryOperator(hasOperatorName("-"),
550                      hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
551 
552   auto HasIncOp = anyOf(ignoringImpCasts(IncOp), hasDescendant(IncOp));
553   auto HasDecOp = anyOf(ignoringImpCasts(DecOp), hasDescendant(DecOp));
554 
555   auto Container = ignoringImpCasts(cxxMemberCallExpr(hasDescendant(declRefExpr(
556       hasType(hasUnqualifiedDesugaredType(recordType(hasDeclaration(recordDecl(
557           hasAnyName("::std::vector", "::std::list", "::std::deque"))))))))));
558 
559   auto StringTy = type(hasUnqualifiedDesugaredType(recordType(
560       hasDeclaration(cxxRecordDecl(hasName("::std::basic_string"))))));
561 
562   auto AnyOfStringTy =
563       anyOf(hasType(StringTy), hasType(qualType(pointsTo(StringTy))));
564 
565   auto CharTyArray = hasType(qualType(hasCanonicalType(
566       arrayType(hasElementType(isAnyCharacter())).bind(DestArrayTyName))));
567 
568   auto CharTyPointer = hasType(
569       qualType(hasCanonicalType(pointerType(pointee(isAnyCharacter())))));
570 
571   auto AnyOfCharTy = anyOf(CharTyArray, CharTyPointer);
572 
573   //===--------------------------------------------------------------------===//
574   // The following six cases match problematic length expressions.
575   //===--------------------------------------------------------------------===//
576 
577   // - Example:  char src[] = "foo";       strlen(src);
578   auto Strlen =
579       callExpr(callee(functionDecl(hasAnyName("::strlen", "::wcslen"))))
580           .bind(WrongLengthExprName);
581 
582   // - Example:  std::string str = "foo";  str.size();
583   auto SizeOrLength =
584       cxxMemberCallExpr(
585           allOf(on(expr(AnyOfStringTy).bind("Foo")),
586                 has(memberExpr(member(hasAnyName("size", "length"))))))
587           .bind(WrongLengthExprName);
588 
589   // - Example:  char src[] = "foo";       sizeof(src);
590   auto SizeOfCharExpr = unaryExprOrTypeTraitExpr(has(expr(AnyOfCharTy)));
591 
592   auto WrongLength =
593       ignoringImpCasts(anyOf(Strlen, SizeOrLength, hasDescendant(Strlen),
594                              hasDescendant(SizeOrLength)));
595 
596   // - Example:  length = strlen(src);
597   auto DREWithoutInc =
598       ignoringImpCasts(declRefExpr(to(varDecl(hasInitializer(WrongLength)))));
599 
600   auto AnyOfCallOrDREWithoutInc = anyOf(DREWithoutInc, WrongLength);
601 
602   // - Example:  int getLength(const char *str) { return strlen(str); }
603   auto CallExprReturnWithoutInc = ignoringImpCasts(callExpr(callee(functionDecl(
604       hasBody(has(returnStmt(hasReturnValue(AnyOfCallOrDREWithoutInc))))))));
605 
606   // - Example:  int length = getLength(src);
607   auto DREHasReturnWithoutInc = ignoringImpCasts(
608       declRefExpr(to(varDecl(hasInitializer(CallExprReturnWithoutInc)))));
609 
610   auto AnyOfWrongLengthInit =
611       anyOf(WrongLength, AnyOfCallOrDREWithoutInc, CallExprReturnWithoutInc,
612             DREHasReturnWithoutInc);
613 
614   //===--------------------------------------------------------------------===//
615   // The following five cases match the 'destination' array length's
616   // expression which is used in 'memcpy()' and 'memmove()' matchers.
617   //===--------------------------------------------------------------------===//
618 
619   // Note: Sometimes the size of char is explicitly written out.
620   auto SizeExpr = anyOf(SizeOfCharExpr, integerLiteral(equals(1)));
621 
622   auto MallocLengthExpr = allOf(
623       callee(functionDecl(
624           hasAnyName("::alloca", "::calloc", "malloc", "realloc"))),
625       hasAnyArgument(allOf(unless(SizeExpr), expr().bind(DestMallocExprName))));
626 
627   // - Example:  (char *)malloc(length);
628   auto DestMalloc = anyOf(callExpr(MallocLengthExpr),
629                           hasDescendant(callExpr(MallocLengthExpr)));
630 
631   // - Example:  new char[length];
632   auto DestCXXNewExpr = ignoringImpCasts(
633       cxxNewExpr(hasArraySize(expr().bind(DestMallocExprName))));
634 
635   auto AnyOfDestInit = anyOf(DestMalloc, DestCXXNewExpr);
636 
637   // - Example:  char dest[13];  or  char dest[length];
638   auto DestArrayTyDecl = declRefExpr(
639       to(anyOf(varDecl(CharTyArray).bind(DestVarDeclName),
640                varDecl(hasInitializer(AnyOfDestInit)).bind(DestVarDeclName))));
641 
642   // - Example:  foo[bar[baz]].qux; (or just ParmVarDecl)
643   auto DestUnknownDecl =
644       declRefExpr(allOf(to(varDecl(AnyOfCharTy).bind(DestVarDeclName)),
645                         expr().bind(UnknownDestName)))
646           .bind(DestExprName);
647 
648   auto AnyOfDestDecl = ignoringImpCasts(
649       anyOf(allOf(hasDefinition(anyOf(AnyOfDestInit, DestArrayTyDecl,
650                                       hasDescendant(DestArrayTyDecl))),
651                   expr().bind(DestExprName)),
652             anyOf(DestUnknownDecl, hasDescendant(DestUnknownDecl))));
653 
654   auto NullTerminatorExpr = binaryOperator(
655       hasLHS(anyOf(hasDescendant(declRefExpr(to(varDecl(
656                        equalsBoundNode(std::string(DestVarDeclName)))))),
657                    hasDescendant(declRefExpr(
658                        equalsBoundNode(std::string(DestExprName)))))),
659       hasRHS(ignoringImpCasts(
660           anyOf(characterLiteral(equals(0U)), integerLiteral(equals(0))))));
661 
662   auto SrcDecl = declRefExpr(
663       allOf(to(decl().bind(SrcVarDeclName)),
664             anyOf(hasAncestor(cxxMemberCallExpr().bind(SrcExprName)),
665                   expr().bind(SrcExprName))));
666 
667   auto AnyOfSrcDecl =
668       ignoringImpCasts(anyOf(stringLiteral().bind(SrcExprName),
669                              hasDescendant(stringLiteral().bind(SrcExprName)),
670                              SrcDecl, hasDescendant(SrcDecl)));
671 
672   //===--------------------------------------------------------------------===//
673   // Match the problematic function calls.
674   //===--------------------------------------------------------------------===//
675 
676   struct CallContext {
677     CallContext(StringRef Name, Optional<unsigned> DestinationPos,
678                 Optional<unsigned> SourcePos, unsigned LengthPos,
679                 bool WithIncrease)
680         : Name(Name), DestinationPos(DestinationPos), SourcePos(SourcePos),
681           LengthPos(LengthPos), WithIncrease(WithIncrease){};
682 
683     StringRef Name;
684     Optional<unsigned> DestinationPos;
685     Optional<unsigned> SourcePos;
686     unsigned LengthPos;
687     bool WithIncrease;
688   };
689 
690   auto MatchDestination = [=](CallContext CC) {
691     return hasArgument(*CC.DestinationPos,
692                        allOf(AnyOfDestDecl,
693                              unless(hasAncestor(compoundStmt(
694                                  hasDescendant(NullTerminatorExpr)))),
695                              unless(Container)));
696   };
697 
698   auto MatchSource = [=](CallContext CC) {
699     return hasArgument(*CC.SourcePos, AnyOfSrcDecl);
700   };
701 
702   auto MatchGivenLength = [=](CallContext CC) {
703     return hasArgument(
704         CC.LengthPos,
705         allOf(
706             anyOf(
707                 ignoringImpCasts(integerLiteral().bind(WrongLengthExprName)),
708                 allOf(unless(hasDefinition(SizeOfCharExpr)),
709                       allOf(CC.WithIncrease
710                                 ? ignoringImpCasts(hasDefinition(HasIncOp))
711                                 : ignoringImpCasts(allOf(
712                                       unless(hasDefinition(HasIncOp)),
713                                       anyOf(hasDefinition(binaryOperator().bind(
714                                                 UnknownLengthName)),
715                                             hasDefinition(anything())))),
716                             AnyOfWrongLengthInit))),
717             expr().bind(LengthExprName)));
718   };
719 
720   auto MatchCall = [=](CallContext CC) {
721     std::string CharHandlerFuncName = "::" + CC.Name.str();
722 
723     // Try to match with 'wchar_t' based function calls.
724     std::string WcharHandlerFuncName =
725         "::" + (CC.Name.startswith("mem") ? "w" + CC.Name.str()
726                                           : "wcs" + CC.Name.substr(3).str());
727 
728     return allOf(callee(functionDecl(
729                      hasAnyName(CharHandlerFuncName, WcharHandlerFuncName))),
730                  MatchGivenLength(CC));
731   };
732 
733   auto Match = [=](CallContext CC) {
734     if (CC.DestinationPos && CC.SourcePos)
735       return allOf(MatchCall(CC), MatchDestination(CC), MatchSource(CC));
736 
737     if (CC.DestinationPos && !CC.SourcePos)
738       return allOf(MatchCall(CC), MatchDestination(CC),
739                    hasArgument(*CC.DestinationPos, anything()));
740 
741     if (!CC.DestinationPos && CC.SourcePos)
742       return allOf(MatchCall(CC), MatchSource(CC),
743                    hasArgument(*CC.SourcePos, anything()));
744 
745     llvm_unreachable("Unhandled match");
746   };
747 
748   // void *memcpy(void *dest, const void *src, size_t count)
749   auto Memcpy = Match({"memcpy", 0, 1, 2, false});
750 
751   // errno_t memcpy_s(void *dest, size_t ds, const void *src, size_t count)
752   auto Memcpy_s = Match({"memcpy_s", 0, 2, 3, false});
753 
754   // void *memchr(const void *src, int c, size_t count)
755   auto Memchr = Match({"memchr", None, 0, 2, false});
756 
757   // void *memmove(void *dest, const void *src, size_t count)
758   auto Memmove = Match({"memmove", 0, 1, 2, false});
759 
760   // errno_t memmove_s(void *dest, size_t ds, const void *src, size_t count)
761   auto Memmove_s = Match({"memmove_s", 0, 2, 3, false});
762 
763   // int strncmp(const char *str1, const char *str2, size_t count);
764   auto StrncmpRHS = Match({"strncmp", None, 1, 2, true});
765   auto StrncmpLHS = Match({"strncmp", None, 0, 2, true});
766 
767   // size_t strxfrm(char *dest, const char *src, size_t count);
768   auto Strxfrm = Match({"strxfrm", 0, 1, 2, false});
769 
770   // errno_t strerror_s(char *buffer, size_t bufferSize, int errnum);
771   auto Strerror_s = Match({"strerror_s", 0, None, 1, false});
772 
773   auto AnyOfMatchers = anyOf(Memcpy, Memcpy_s, Memmove, Memmove_s, StrncmpRHS,
774                              StrncmpLHS, Strxfrm, Strerror_s);
775 
776   Finder->addMatcher(callExpr(AnyOfMatchers).bind(FunctionExprName), this);
777 
778   // Need to remove the CastExpr from 'memchr()' as 'strchr()' returns 'char *'.
779   Finder->addMatcher(
780       callExpr(Memchr,
781                unless(hasAncestor(castExpr(unless(implicitCastExpr())))))
782           .bind(FunctionExprName),
783       this);
784   Finder->addMatcher(
785       castExpr(allOf(unless(implicitCastExpr()),
786                      has(callExpr(Memchr).bind(FunctionExprName))))
787           .bind(CastExprName),
788       this);
789 }
790 
791 void NotNullTerminatedResultCheck::check(
792     const MatchFinder::MatchResult &Result) {
793   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
794   if (FunctionExpr->getBeginLoc().isMacroID())
795     return;
796 
797   if (WantToUseSafeFunctions && PP->isMacroDefined("__STDC_LIB_EXT1__")) {
798     Optional<bool> AreSafeFunctionsWanted;
799 
800     Preprocessor::macro_iterator It = PP->macro_begin();
801     while (It != PP->macro_end() && !AreSafeFunctionsWanted.hasValue()) {
802       if (It->first->getName() == "__STDC_WANT_LIB_EXT1__") {
803         const auto *MI = PP->getMacroInfo(It->first);
804         const auto &T = MI->tokens().back();
805         StringRef ValueStr = StringRef(T.getLiteralData(), T.getLength());
806         llvm::APInt IntValue;
807         ValueStr.getAsInteger(10, IntValue);
808         AreSafeFunctionsWanted = IntValue.getZExtValue();
809       }
810 
811       ++It;
812     }
813 
814     if (AreSafeFunctionsWanted.hasValue())
815       UseSafeFunctions = AreSafeFunctionsWanted.getValue();
816   }
817 
818   StringRef Name = FunctionExpr->getDirectCallee()->getName();
819   if (Name.startswith("mem") || Name.startswith("wmem"))
820     memoryHandlerFunctionFix(Name, Result);
821   else if (Name == "strerror_s")
822     strerror_sFix(Result);
823   else if (Name.endswith("ncmp"))
824     ncmpFix(Name, Result);
825   else if (Name.endswith("xfrm"))
826     xfrmFix(Name, Result);
827 }
828 
829 void NotNullTerminatedResultCheck::memoryHandlerFunctionFix(
830     StringRef Name, const MatchFinder::MatchResult &Result) {
831   if (isCorrectGivenLength(Result))
832     return;
833 
834   if (Name.endswith("chr")) {
835     memchrFix(Name, Result);
836     return;
837   }
838 
839   if ((Name.contains("cpy") || Name.contains("move")) &&
840       (isDestAndSrcEquals(Result) || isFixedGivenLengthAndUnknownSrc(Result)))
841     return;
842 
843   auto Diag =
844       diag(Result.Nodes.getNodeAs<CallExpr>(FunctionExprName)->getBeginLoc(),
845            "the result from calling '%0' is not null-terminated")
846       << Name;
847 
848   if (Name.endswith("cpy")) {
849     memcpyFix(Name, Result, Diag);
850   } else if (Name.endswith("cpy_s")) {
851     memcpy_sFix(Name, Result, Diag);
852   } else if (Name.endswith("move")) {
853     memmoveFix(Name, Result, Diag);
854   } else if (Name.endswith("move_s")) {
855     isDestCapacityFix(Result, Diag);
856     lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
857   }
858 }
859 
860 void NotNullTerminatedResultCheck::memcpyFix(
861     StringRef Name, const MatchFinder::MatchResult &Result,
862     DiagnosticBuilder &Diag) {
863   bool IsOverflows = isDestCapacityFix(Result, Diag);
864   bool IsDestFixed = isDestExprFix(Result, Diag);
865 
866   bool IsCopy =
867       isGivenLengthEqualToSrcLength(Result) || isDestBasedOnGivenLength(Result);
868 
869   bool IsSafe = UseSafeFunctions && IsOverflows && isKnownDest(Result) &&
870                 !isDestBasedOnGivenLength(Result);
871 
872   bool IsDestLengthNotRequired =
873       IsSafe && getLangOpts().CPlusPlus &&
874       Result.Nodes.getNodeAs<ArrayType>(DestArrayTyName) && !IsDestFixed;
875 
876   renameMemcpy(Name, IsCopy, IsSafe, Result, Diag);
877 
878   if (IsSafe && !IsDestLengthNotRequired)
879     insertDestCapacityArg(IsOverflows, Name, Result, Diag);
880 
881   if (IsCopy)
882     removeArg(2, Result, Diag);
883 
884   if (!IsCopy && !IsSafe)
885     insertNullTerminatorExpr(Name, Result, Diag);
886 }
887 
888 void NotNullTerminatedResultCheck::memcpy_sFix(
889     StringRef Name, const MatchFinder::MatchResult &Result,
890     DiagnosticBuilder &Diag) {
891   bool IsOverflows = isDestCapacityFix(Result, Diag);
892   bool IsDestFixed = isDestExprFix(Result, Diag);
893 
894   bool RemoveDestLength = getLangOpts().CPlusPlus &&
895                           Result.Nodes.getNodeAs<ArrayType>(DestArrayTyName) &&
896                           !IsDestFixed;
897   bool IsCopy = isGivenLengthEqualToSrcLength(Result);
898   bool IsSafe = IsOverflows;
899 
900   renameMemcpy(Name, IsCopy, IsSafe, Result, Diag);
901 
902   if (!IsSafe || (IsSafe && RemoveDestLength))
903     removeArg(1, Result, Diag);
904   else if (IsOverflows && isKnownDest(Result))
905     lengthArgPosHandle(1, LengthHandleKind::Increase, Result, Diag);
906 
907   if (IsCopy)
908     removeArg(3, Result, Diag);
909 
910   if (!IsCopy && !IsSafe)
911     insertNullTerminatorExpr(Name, Result, Diag);
912 }
913 
914 void NotNullTerminatedResultCheck::memchrFix(
915     StringRef Name, const MatchFinder::MatchResult &Result) {
916   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
917   if (const auto GivenCL = dyn_cast<CharacterLiteral>(FunctionExpr->getArg(1)))
918     if (GivenCL->getValue() != 0)
919       return;
920 
921   auto Diag = diag(FunctionExpr->getArg(2)->IgnoreParenCasts()->getBeginLoc(),
922                    "the length is too short to include the null terminator");
923 
924   if (const auto *CastExpr = Result.Nodes.getNodeAs<Expr>(CastExprName)) {
925     const auto CastRemoveFix = FixItHint::CreateRemoval(
926         SourceRange(CastExpr->getBeginLoc(),
927                     FunctionExpr->getBeginLoc().getLocWithOffset(-1)));
928     Diag << CastRemoveFix;
929   }
930 
931   StringRef NewFuncName = (Name[0] != 'w') ? "strchr" : "wcschr";
932   renameFunc(NewFuncName, Result, Diag);
933   removeArg(2, Result, Diag);
934 }
935 
936 void NotNullTerminatedResultCheck::memmoveFix(
937     StringRef Name, const MatchFinder::MatchResult &Result,
938     DiagnosticBuilder &Diag) {
939   bool IsOverflows = isDestCapacityFix(Result, Diag);
940 
941   if (UseSafeFunctions && isKnownDest(Result)) {
942     renameFunc((Name[0] != 'w') ? "memmove_s" : "wmemmove_s", Result, Diag);
943     insertDestCapacityArg(IsOverflows, Name, Result, Diag);
944   }
945 
946   lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
947 }
948 
949 void NotNullTerminatedResultCheck::strerror_sFix(
950     const MatchFinder::MatchResult &Result) {
951   auto Diag =
952       diag(Result.Nodes.getNodeAs<CallExpr>(FunctionExprName)->getBeginLoc(),
953            "the result from calling 'strerror_s' is not null-terminated and "
954            "missing the last character of the error message");
955 
956   isDestCapacityFix(Result, Diag);
957   lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
958 }
959 
960 void NotNullTerminatedResultCheck::ncmpFix(
961     StringRef Name, const MatchFinder::MatchResult &Result) {
962   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
963   const Expr *FirstArgExpr = FunctionExpr->getArg(0)->IgnoreImpCasts();
964   const Expr *SecondArgExpr = FunctionExpr->getArg(1)->IgnoreImpCasts();
965   bool IsLengthTooLong = false;
966 
967   if (const CallExpr *StrlenExpr = getStrlenExpr(Result)) {
968     const Expr *LengthExprArg = StrlenExpr->getArg(0);
969     StringRef FirstExprStr = exprToStr(FirstArgExpr, Result).trim();
970     StringRef SecondExprStr = exprToStr(SecondArgExpr, Result).trim();
971     StringRef LengthArgStr = exprToStr(LengthExprArg, Result).trim();
972     IsLengthTooLong =
973         LengthArgStr == FirstExprStr || LengthArgStr == SecondExprStr;
974   } else {
975     int SrcLength =
976         getLength(Result.Nodes.getNodeAs<Expr>(SrcExprName), Result);
977     int GivenLength = getGivenLength(Result);
978     if (SrcLength != 0 && GivenLength != 0)
979       IsLengthTooLong = GivenLength > SrcLength;
980   }
981 
982   if (!IsLengthTooLong && !isStringDataAndLength(Result))
983     return;
984 
985   auto Diag = diag(FunctionExpr->getArg(2)->IgnoreParenCasts()->getBeginLoc(),
986                    "comparison length is too long and might lead to a "
987                    "buffer overflow");
988 
989   lengthArgHandle(LengthHandleKind::Decrease, Result, Diag);
990 }
991 
992 void NotNullTerminatedResultCheck::xfrmFix(
993     StringRef Name, const MatchFinder::MatchResult &Result) {
994   if (!isDestCapacityOverflows(Result))
995     return;
996 
997   auto Diag =
998       diag(Result.Nodes.getNodeAs<CallExpr>(FunctionExprName)->getBeginLoc(),
999            "the result from calling '%0' is not null-terminated")
1000       << Name;
1001 
1002   isDestCapacityFix(Result, Diag);
1003   lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
1004 }
1005 
1006 } // namespace bugprone
1007 } // namespace tidy
1008 } // namespace clang
1009