xref: /llvm-project/clang-tools-extra/clang-tidy/bugprone/NotNullTerminatedResultCheck.cpp (revision 71f557355ddaea358c43b151de3a0e045aaa0863)
1 //===--- NotNullTerminatedResultCheck.cpp - clang-tidy ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "NotNullTerminatedResultCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Frontend/CompilerInstance.h"
13 #include "clang/Lex/Lexer.h"
14 #include "clang/Lex/PPCallbacks.h"
15 #include "clang/Lex/Preprocessor.h"
16 #include <optional>
17 
18 using namespace clang::ast_matchers;
19 
20 namespace clang {
21 namespace tidy {
22 namespace bugprone {
23 
24 constexpr llvm::StringLiteral FunctionExprName = "FunctionExpr";
25 constexpr llvm::StringLiteral CastExprName = "CastExpr";
26 constexpr llvm::StringLiteral UnknownDestName = "UnknownDest";
27 constexpr llvm::StringLiteral DestArrayTyName = "DestArrayTy";
28 constexpr llvm::StringLiteral DestVarDeclName = "DestVarDecl";
29 constexpr llvm::StringLiteral DestMallocExprName = "DestMalloc";
30 constexpr llvm::StringLiteral DestExprName = "DestExpr";
31 constexpr llvm::StringLiteral SrcVarDeclName = "SrcVarDecl";
32 constexpr llvm::StringLiteral SrcExprName = "SrcExpr";
33 constexpr llvm::StringLiteral LengthExprName = "LengthExpr";
34 constexpr llvm::StringLiteral WrongLengthExprName = "WrongLength";
35 constexpr llvm::StringLiteral UnknownLengthName = "UnknownLength";
36 
37 enum class LengthHandleKind { Increase, Decrease };
38 
39 namespace {
40 static Preprocessor *PP;
41 } // namespace
42 
43 // Returns the expression of destination's capacity which is part of a
44 // 'VariableArrayType', 'ConstantArrayTypeLoc' or an argument of a 'malloc()'
45 // family function call.
46 static const Expr *getDestCapacityExpr(const MatchFinder::MatchResult &Result) {
47   if (const auto *DestMalloc = Result.Nodes.getNodeAs<Expr>(DestMallocExprName))
48     return DestMalloc;
49 
50   if (const auto *DestVAT =
51           Result.Nodes.getNodeAs<VariableArrayType>(DestArrayTyName))
52     return DestVAT->getSizeExpr();
53 
54   if (const auto *DestVD = Result.Nodes.getNodeAs<VarDecl>(DestVarDeclName))
55     if (const TypeLoc DestTL = DestVD->getTypeSourceInfo()->getTypeLoc())
56       if (const auto DestCTL = DestTL.getAs<ConstantArrayTypeLoc>())
57         return DestCTL.getSizeExpr();
58 
59   return nullptr;
60 }
61 
62 // Returns the length of \p E as an 'IntegerLiteral' or a 'StringLiteral'
63 // without the null-terminator.
64 static unsigned getLength(const Expr *E,
65                           const MatchFinder::MatchResult &Result) {
66   if (!E)
67     return 0;
68 
69   Expr::EvalResult Length;
70   E = E->IgnoreImpCasts();
71 
72   if (const auto *LengthDRE = dyn_cast<DeclRefExpr>(E))
73     if (const auto *LengthVD = dyn_cast<VarDecl>(LengthDRE->getDecl()))
74       if (!isa<ParmVarDecl>(LengthVD))
75         if (const Expr *LengthInit = LengthVD->getInit())
76           if (LengthInit->EvaluateAsInt(Length, *Result.Context))
77             return Length.Val.getInt().getZExtValue();
78 
79   if (const auto *LengthIL = dyn_cast<IntegerLiteral>(E))
80     return LengthIL->getValue().getZExtValue();
81 
82   if (const auto *StrDRE = dyn_cast<DeclRefExpr>(E))
83     if (const auto *StrVD = dyn_cast<VarDecl>(StrDRE->getDecl()))
84       if (const Expr *StrInit = StrVD->getInit())
85         if (const auto *StrSL =
86                 dyn_cast<StringLiteral>(StrInit->IgnoreImpCasts()))
87           return StrSL->getLength();
88 
89   if (const auto *SrcSL = dyn_cast<StringLiteral>(E))
90     return SrcSL->getLength();
91 
92   return 0;
93 }
94 
95 // Returns the capacity of the destination array.
96 // For example in 'char dest[13]; memcpy(dest, ...)' it returns 13.
97 static int getDestCapacity(const MatchFinder::MatchResult &Result) {
98   if (const auto *DestCapacityExpr = getDestCapacityExpr(Result))
99     return getLength(DestCapacityExpr, Result);
100 
101   return 0;
102 }
103 
104 // Returns the 'strlen()' if it is the given length.
105 static const CallExpr *getStrlenExpr(const MatchFinder::MatchResult &Result) {
106   if (const auto *StrlenExpr =
107           Result.Nodes.getNodeAs<CallExpr>(WrongLengthExprName))
108     if (const Decl *D = StrlenExpr->getCalleeDecl())
109       if (const FunctionDecl *FD = D->getAsFunction())
110         if (const IdentifierInfo *II = FD->getIdentifier())
111           if (II->isStr("strlen") || II->isStr("wcslen"))
112             return StrlenExpr;
113 
114   return nullptr;
115 }
116 
117 // Returns the length which is given in the memory/string handler function.
118 // For example in 'memcpy(dest, "foobar", 3)' it returns 3.
119 static int getGivenLength(const MatchFinder::MatchResult &Result) {
120   if (Result.Nodes.getNodeAs<Expr>(UnknownLengthName))
121     return 0;
122 
123   if (int Length =
124           getLength(Result.Nodes.getNodeAs<Expr>(WrongLengthExprName), Result))
125     return Length;
126 
127   if (int Length =
128           getLength(Result.Nodes.getNodeAs<Expr>(LengthExprName), Result))
129     return Length;
130 
131   // Special case, for example 'strlen("foo")'.
132   if (const CallExpr *StrlenCE = getStrlenExpr(Result))
133     if (const Expr *Arg = StrlenCE->getArg(0)->IgnoreImpCasts())
134       if (int ArgLength = getLength(Arg, Result))
135         return ArgLength;
136 
137   return 0;
138 }
139 
140 // Returns a string representation of \p E.
141 static StringRef exprToStr(const Expr *E,
142                            const MatchFinder::MatchResult &Result) {
143   if (!E)
144     return "";
145 
146   return Lexer::getSourceText(
147       CharSourceRange::getTokenRange(E->getSourceRange()),
148       *Result.SourceManager, Result.Context->getLangOpts(), nullptr);
149 }
150 
151 // Returns the proper token based end location of \p E.
152 static SourceLocation exprLocEnd(const Expr *E,
153                                  const MatchFinder::MatchResult &Result) {
154   return Lexer::getLocForEndOfToken(E->getEndLoc(), 0, *Result.SourceManager,
155                                     Result.Context->getLangOpts());
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Rewrite decision helper functions.
160 //===----------------------------------------------------------------------===//
161 
162 // Increment by integer '1' can result in overflow if it is the maximal value.
163 // After that it would be extended to 'size_t' and its value would be wrong,
164 // therefore we have to inject '+ 1UL' instead.
165 static bool isInjectUL(const MatchFinder::MatchResult &Result) {
166   return getGivenLength(Result) == std::numeric_limits<int>::max();
167 }
168 
169 // If the capacity of the destination array is unknown it is denoted as unknown.
170 static bool isKnownDest(const MatchFinder::MatchResult &Result) {
171   return !Result.Nodes.getNodeAs<Expr>(UnknownDestName);
172 }
173 
174 // True if the capacity of the destination array is based on the given length,
175 // therefore we assume that it cannot overflow (e.g. 'malloc(given_length + 1)'
176 static bool isDestBasedOnGivenLength(const MatchFinder::MatchResult &Result) {
177   StringRef DestCapacityExprStr =
178       exprToStr(getDestCapacityExpr(Result), Result).trim();
179   StringRef LengthExprStr =
180       exprToStr(Result.Nodes.getNodeAs<Expr>(LengthExprName), Result).trim();
181 
182   return DestCapacityExprStr != "" && LengthExprStr != "" &&
183          DestCapacityExprStr.contains(LengthExprStr);
184 }
185 
186 // Writing and reading from the same memory cannot remove the null-terminator.
187 static bool isDestAndSrcEquals(const MatchFinder::MatchResult &Result) {
188   if (const auto *DestDRE = Result.Nodes.getNodeAs<DeclRefExpr>(DestExprName))
189     if (const auto *SrcDRE = Result.Nodes.getNodeAs<DeclRefExpr>(SrcExprName))
190       return DestDRE->getDecl()->getCanonicalDecl() ==
191              SrcDRE->getDecl()->getCanonicalDecl();
192 
193   return false;
194 }
195 
196 // For example 'std::string str = "foo"; memcpy(dst, str.data(), str.length())'.
197 static bool isStringDataAndLength(const MatchFinder::MatchResult &Result) {
198   const auto *DestExpr =
199       Result.Nodes.getNodeAs<CXXMemberCallExpr>(DestExprName);
200   const auto *SrcExpr = Result.Nodes.getNodeAs<CXXMemberCallExpr>(SrcExprName);
201   const auto *LengthExpr =
202       Result.Nodes.getNodeAs<CXXMemberCallExpr>(WrongLengthExprName);
203 
204   StringRef DestStr = "", SrcStr = "", LengthStr = "";
205   if (DestExpr)
206     if (const CXXMethodDecl *DestMD = DestExpr->getMethodDecl())
207       DestStr = DestMD->getName();
208 
209   if (SrcExpr)
210     if (const CXXMethodDecl *SrcMD = SrcExpr->getMethodDecl())
211       SrcStr = SrcMD->getName();
212 
213   if (LengthExpr)
214     if (const CXXMethodDecl *LengthMD = LengthExpr->getMethodDecl())
215       LengthStr = LengthMD->getName();
216 
217   return (LengthStr == "length" || LengthStr == "size") &&
218          (SrcStr == "data" || DestStr == "data");
219 }
220 
221 static bool
222 isGivenLengthEqualToSrcLength(const MatchFinder::MatchResult &Result) {
223   if (Result.Nodes.getNodeAs<Expr>(UnknownLengthName))
224     return false;
225 
226   if (isStringDataAndLength(Result))
227     return true;
228 
229   int GivenLength = getGivenLength(Result);
230   int SrcLength = getLength(Result.Nodes.getNodeAs<Expr>(SrcExprName), Result);
231 
232   if (GivenLength != 0 && SrcLength != 0 && GivenLength == SrcLength)
233     return true;
234 
235   if (const auto *LengthExpr = Result.Nodes.getNodeAs<Expr>(LengthExprName))
236     if (isa<BinaryOperator>(LengthExpr->IgnoreParenImpCasts()))
237       return false;
238 
239   // Check the strlen()'s argument's 'VarDecl' is equal to the source 'VarDecl'.
240   if (const CallExpr *StrlenCE = getStrlenExpr(Result))
241     if (const auto *ArgDRE =
242             dyn_cast<DeclRefExpr>(StrlenCE->getArg(0)->IgnoreImpCasts()))
243       if (const auto *SrcVD = Result.Nodes.getNodeAs<VarDecl>(SrcVarDeclName))
244         return dyn_cast<VarDecl>(ArgDRE->getDecl()) == SrcVD;
245 
246   return false;
247 }
248 
249 static bool isCorrectGivenLength(const MatchFinder::MatchResult &Result) {
250   if (Result.Nodes.getNodeAs<Expr>(UnknownLengthName))
251     return false;
252 
253   return !isGivenLengthEqualToSrcLength(Result);
254 }
255 
256 // If we rewrite the function call we need to create extra space to hold the
257 // null terminator. The new necessary capacity overflows without that '+ 1'
258 // size and we need to correct the given capacity.
259 static bool isDestCapacityOverflows(const MatchFinder::MatchResult &Result) {
260   if (!isKnownDest(Result))
261     return true;
262 
263   const Expr *DestCapacityExpr = getDestCapacityExpr(Result);
264   int DestCapacity = getLength(DestCapacityExpr, Result);
265   int GivenLength = getGivenLength(Result);
266 
267   if (GivenLength != 0 && DestCapacity != 0)
268     return isGivenLengthEqualToSrcLength(Result) && DestCapacity == GivenLength;
269 
270   // Assume that the destination array's capacity cannot overflow if the
271   // expression of the memory allocation contains '+ 1'.
272   StringRef DestCapacityExprStr = exprToStr(DestCapacityExpr, Result);
273   if (DestCapacityExprStr.contains("+1") || DestCapacityExprStr.contains("+ 1"))
274     return false;
275 
276   return true;
277 }
278 
279 static bool
280 isFixedGivenLengthAndUnknownSrc(const MatchFinder::MatchResult &Result) {
281   if (Result.Nodes.getNodeAs<IntegerLiteral>(WrongLengthExprName))
282     return !getLength(Result.Nodes.getNodeAs<Expr>(SrcExprName), Result);
283 
284   return false;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // Code injection functions.
289 //===----------------------------------------------------------------------===//
290 
291 // Increase or decrease \p LengthExpr by one.
292 static void lengthExprHandle(const Expr *LengthExpr,
293                              LengthHandleKind LengthHandle,
294                              const MatchFinder::MatchResult &Result,
295                              DiagnosticBuilder &Diag) {
296   LengthExpr = LengthExpr->IgnoreParenImpCasts();
297 
298   // See whether we work with a macro.
299   bool IsMacroDefinition = false;
300   StringRef LengthExprStr = exprToStr(LengthExpr, Result);
301   Preprocessor::macro_iterator It = PP->macro_begin();
302   while (It != PP->macro_end() && !IsMacroDefinition) {
303     if (It->first->getName() == LengthExprStr)
304       IsMacroDefinition = true;
305 
306     ++It;
307   }
308 
309   // Try to obtain an 'IntegerLiteral' and adjust it.
310   if (!IsMacroDefinition) {
311     if (const auto *LengthIL = dyn_cast<IntegerLiteral>(LengthExpr)) {
312       size_t NewLength = LengthIL->getValue().getZExtValue() +
313                          (LengthHandle == LengthHandleKind::Increase
314                               ? (isInjectUL(Result) ? 1UL : 1)
315                               : -1);
316 
317       const auto NewLengthFix = FixItHint::CreateReplacement(
318           LengthIL->getSourceRange(),
319           (Twine(NewLength) + (isInjectUL(Result) ? "UL" : "")).str());
320       Diag << NewLengthFix;
321       return;
322     }
323   }
324 
325   // Try to obtain and remove the '+ 1' string as a decrement fix.
326   const auto *BO = dyn_cast<BinaryOperator>(LengthExpr);
327   if (BO && BO->getOpcode() == BO_Add &&
328       LengthHandle == LengthHandleKind::Decrease) {
329     const Expr *LhsExpr = BO->getLHS()->IgnoreImpCasts();
330     const Expr *RhsExpr = BO->getRHS()->IgnoreImpCasts();
331 
332     if (const auto *LhsIL = dyn_cast<IntegerLiteral>(LhsExpr)) {
333       if (LhsIL->getValue().getZExtValue() == 1) {
334         Diag << FixItHint::CreateRemoval(
335             {LhsIL->getBeginLoc(),
336              RhsExpr->getBeginLoc().getLocWithOffset(-1)});
337         return;
338       }
339     }
340 
341     if (const auto *RhsIL = dyn_cast<IntegerLiteral>(RhsExpr)) {
342       if (RhsIL->getValue().getZExtValue() == 1) {
343         Diag << FixItHint::CreateRemoval(
344             {LhsExpr->getEndLoc().getLocWithOffset(1), RhsIL->getEndLoc()});
345         return;
346       }
347     }
348   }
349 
350   // Try to inject the '+ 1'/'- 1' string.
351   bool NeedInnerParen = BO && BO->getOpcode() != BO_Add;
352 
353   if (NeedInnerParen)
354     Diag << FixItHint::CreateInsertion(LengthExpr->getBeginLoc(), "(");
355 
356   SmallString<8> Injection;
357   if (NeedInnerParen)
358     Injection += ')';
359   Injection += LengthHandle == LengthHandleKind::Increase ? " + 1" : " - 1";
360   if (isInjectUL(Result))
361     Injection += "UL";
362 
363   Diag << FixItHint::CreateInsertion(exprLocEnd(LengthExpr, Result), Injection);
364 }
365 
366 static void lengthArgHandle(LengthHandleKind LengthHandle,
367                             const MatchFinder::MatchResult &Result,
368                             DiagnosticBuilder &Diag) {
369   const auto *LengthExpr = Result.Nodes.getNodeAs<Expr>(LengthExprName);
370   lengthExprHandle(LengthExpr, LengthHandle, Result, Diag);
371 }
372 
373 static void lengthArgPosHandle(unsigned ArgPos, LengthHandleKind LengthHandle,
374                                const MatchFinder::MatchResult &Result,
375                                DiagnosticBuilder &Diag) {
376   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
377   lengthExprHandle(FunctionExpr->getArg(ArgPos), LengthHandle, Result, Diag);
378 }
379 
380 // The string handler functions are only operates with plain 'char'/'wchar_t'
381 // without 'unsigned/signed', therefore we need to cast it.
382 static bool isDestExprFix(const MatchFinder::MatchResult &Result,
383                           DiagnosticBuilder &Diag) {
384   const auto *Dest = Result.Nodes.getNodeAs<Expr>(DestExprName);
385   if (!Dest)
386     return false;
387 
388   std::string TempTyStr = Dest->getType().getAsString();
389   StringRef TyStr = TempTyStr;
390   if (TyStr.startswith("char") || TyStr.startswith("wchar_t"))
391     return false;
392 
393   Diag << FixItHint::CreateInsertion(Dest->getBeginLoc(), "(char *)");
394   return true;
395 }
396 
397 // If the destination array is the same length as the given length we have to
398 // increase the capacity by one to create space for the null terminator.
399 static bool isDestCapacityFix(const MatchFinder::MatchResult &Result,
400                               DiagnosticBuilder &Diag) {
401   bool IsOverflows = isDestCapacityOverflows(Result);
402   if (IsOverflows)
403     if (const Expr *CapacityExpr = getDestCapacityExpr(Result))
404       lengthExprHandle(CapacityExpr, LengthHandleKind::Increase, Result, Diag);
405 
406   return IsOverflows;
407 }
408 
409 static void removeArg(int ArgPos, const MatchFinder::MatchResult &Result,
410                       DiagnosticBuilder &Diag) {
411   // This is the following structure: (src, '\0', strlen(src))
412   //                     ArgToRemove:             ~~~~~~~~~~~
413   //                          LHSArg:       ~~~~
414   //                    RemoveArgFix:           ~~~~~~~~~~~~~
415   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
416   const Expr *ArgToRemove = FunctionExpr->getArg(ArgPos);
417   const Expr *LHSArg = FunctionExpr->getArg(ArgPos - 1);
418   const auto RemoveArgFix = FixItHint::CreateRemoval(
419       SourceRange(exprLocEnd(LHSArg, Result),
420                   exprLocEnd(ArgToRemove, Result).getLocWithOffset(-1)));
421   Diag << RemoveArgFix;
422 }
423 
424 static void renameFunc(StringRef NewFuncName,
425                        const MatchFinder::MatchResult &Result,
426                        DiagnosticBuilder &Diag) {
427   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
428   int FuncNameLength =
429       FunctionExpr->getDirectCallee()->getIdentifier()->getLength();
430   SourceRange FuncNameRange(
431       FunctionExpr->getBeginLoc(),
432       FunctionExpr->getBeginLoc().getLocWithOffset(FuncNameLength - 1));
433 
434   const auto FuncNameFix =
435       FixItHint::CreateReplacement(FuncNameRange, NewFuncName);
436   Diag << FuncNameFix;
437 }
438 
439 static void renameMemcpy(StringRef Name, bool IsCopy, bool IsSafe,
440                          const MatchFinder::MatchResult &Result,
441                          DiagnosticBuilder &Diag) {
442   SmallString<10> NewFuncName;
443   NewFuncName = (Name[0] != 'w') ? "str" : "wcs";
444   NewFuncName += IsCopy ? "cpy" : "ncpy";
445   NewFuncName += IsSafe ? "_s" : "";
446   renameFunc(NewFuncName, Result, Diag);
447 }
448 
449 static void insertDestCapacityArg(bool IsOverflows, StringRef Name,
450                                   const MatchFinder::MatchResult &Result,
451                                   DiagnosticBuilder &Diag) {
452   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
453   SmallString<64> NewSecondArg;
454 
455   if (int DestLength = getDestCapacity(Result)) {
456     NewSecondArg = Twine(IsOverflows ? DestLength + 1 : DestLength).str();
457   } else {
458     NewSecondArg =
459         (Twine(exprToStr(getDestCapacityExpr(Result), Result)) +
460          (IsOverflows ? (!isInjectUL(Result) ? " + 1" : " + 1UL") : ""))
461             .str();
462   }
463 
464   NewSecondArg += ", ";
465   const auto InsertNewArgFix = FixItHint::CreateInsertion(
466       FunctionExpr->getArg(1)->getBeginLoc(), NewSecondArg);
467   Diag << InsertNewArgFix;
468 }
469 
470 static void insertNullTerminatorExpr(StringRef Name,
471                                      const MatchFinder::MatchResult &Result,
472                                      DiagnosticBuilder &Diag) {
473   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
474   int FuncLocStartColumn = Result.SourceManager->getPresumedColumnNumber(
475       FunctionExpr->getBeginLoc());
476   SourceRange SpaceRange(
477       FunctionExpr->getBeginLoc().getLocWithOffset(-FuncLocStartColumn + 1),
478       FunctionExpr->getBeginLoc());
479   StringRef SpaceBeforeStmtStr = Lexer::getSourceText(
480       CharSourceRange::getCharRange(SpaceRange), *Result.SourceManager,
481       Result.Context->getLangOpts(), nullptr);
482 
483   SmallString<128> NewAddNullTermExprStr;
484   NewAddNullTermExprStr =
485       (Twine('\n') + SpaceBeforeStmtStr +
486        exprToStr(Result.Nodes.getNodeAs<Expr>(DestExprName), Result) + "[" +
487        exprToStr(Result.Nodes.getNodeAs<Expr>(LengthExprName), Result) +
488        "] = " + ((Name[0] != 'w') ? "\'\\0\';" : "L\'\\0\';"))
489           .str();
490 
491   const auto AddNullTerminatorExprFix = FixItHint::CreateInsertion(
492       exprLocEnd(FunctionExpr, Result).getLocWithOffset(1),
493       NewAddNullTermExprStr);
494   Diag << AddNullTerminatorExprFix;
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // Checker logic with the matchers.
499 //===----------------------------------------------------------------------===//
500 
501 NotNullTerminatedResultCheck::NotNullTerminatedResultCheck(
502     StringRef Name, ClangTidyContext *Context)
503     : ClangTidyCheck(Name, Context),
504       WantToUseSafeFunctions(Options.get("WantToUseSafeFunctions", true)) {}
505 
506 void NotNullTerminatedResultCheck::storeOptions(
507     ClangTidyOptions::OptionMap &Opts) {
508   Options.store(Opts, "WantToUseSafeFunctions", WantToUseSafeFunctions);
509 }
510 
511 void NotNullTerminatedResultCheck::registerPPCallbacks(
512     const SourceManager &SM, Preprocessor *Pp, Preprocessor *ModuleExpanderPP) {
513   PP = Pp;
514 }
515 
516 namespace {
517 AST_MATCHER_P(Expr, hasDefinition, ast_matchers::internal::Matcher<Expr>,
518               InnerMatcher) {
519   const Expr *SimpleNode = &Node;
520   SimpleNode = SimpleNode->IgnoreParenImpCasts();
521 
522   if (InnerMatcher.matches(*SimpleNode, Finder, Builder))
523     return true;
524 
525   auto DREHasInit = ignoringImpCasts(
526       declRefExpr(to(varDecl(hasInitializer(ignoringImpCasts(InnerMatcher))))));
527 
528   if (DREHasInit.matches(*SimpleNode, Finder, Builder))
529     return true;
530 
531   const char *const VarDeclName = "variable-declaration";
532   auto DREHasDefinition = ignoringImpCasts(declRefExpr(
533       allOf(to(varDecl().bind(VarDeclName)),
534             hasAncestor(compoundStmt(hasDescendant(binaryOperator(
535                 hasLHS(declRefExpr(to(varDecl(equalsBoundNode(VarDeclName))))),
536                 hasRHS(ignoringImpCasts(InnerMatcher)))))))));
537 
538   if (DREHasDefinition.matches(*SimpleNode, Finder, Builder))
539     return true;
540 
541   return false;
542 }
543 } // namespace
544 
545 void NotNullTerminatedResultCheck::registerMatchers(MatchFinder *Finder) {
546   auto IncOp =
547       binaryOperator(hasOperatorName("+"),
548                      hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
549 
550   auto DecOp =
551       binaryOperator(hasOperatorName("-"),
552                      hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
553 
554   auto HasIncOp = anyOf(ignoringImpCasts(IncOp), hasDescendant(IncOp));
555   auto HasDecOp = anyOf(ignoringImpCasts(DecOp), hasDescendant(DecOp));
556 
557   auto Container = ignoringImpCasts(cxxMemberCallExpr(hasDescendant(declRefExpr(
558       hasType(hasUnqualifiedDesugaredType(recordType(hasDeclaration(recordDecl(
559           hasAnyName("::std::vector", "::std::list", "::std::deque"))))))))));
560 
561   auto StringTy = type(hasUnqualifiedDesugaredType(recordType(
562       hasDeclaration(cxxRecordDecl(hasName("::std::basic_string"))))));
563 
564   auto AnyOfStringTy =
565       anyOf(hasType(StringTy), hasType(qualType(pointsTo(StringTy))));
566 
567   auto CharTyArray = hasType(qualType(hasCanonicalType(
568       arrayType(hasElementType(isAnyCharacter())).bind(DestArrayTyName))));
569 
570   auto CharTyPointer = hasType(
571       qualType(hasCanonicalType(pointerType(pointee(isAnyCharacter())))));
572 
573   auto AnyOfCharTy = anyOf(CharTyArray, CharTyPointer);
574 
575   //===--------------------------------------------------------------------===//
576   // The following six cases match problematic length expressions.
577   //===--------------------------------------------------------------------===//
578 
579   // - Example:  char src[] = "foo";       strlen(src);
580   auto Strlen =
581       callExpr(callee(functionDecl(hasAnyName("::strlen", "::wcslen"))))
582           .bind(WrongLengthExprName);
583 
584   // - Example:  std::string str = "foo";  str.size();
585   auto SizeOrLength =
586       cxxMemberCallExpr(
587           allOf(on(expr(AnyOfStringTy).bind("Foo")),
588                 has(memberExpr(member(hasAnyName("size", "length"))))))
589           .bind(WrongLengthExprName);
590 
591   // - Example:  char src[] = "foo";       sizeof(src);
592   auto SizeOfCharExpr = unaryExprOrTypeTraitExpr(has(expr(AnyOfCharTy)));
593 
594   auto WrongLength =
595       ignoringImpCasts(anyOf(Strlen, SizeOrLength, hasDescendant(Strlen),
596                              hasDescendant(SizeOrLength)));
597 
598   // - Example:  length = strlen(src);
599   auto DREWithoutInc =
600       ignoringImpCasts(declRefExpr(to(varDecl(hasInitializer(WrongLength)))));
601 
602   auto AnyOfCallOrDREWithoutInc = anyOf(DREWithoutInc, WrongLength);
603 
604   // - Example:  int getLength(const char *str) { return strlen(str); }
605   auto CallExprReturnWithoutInc = ignoringImpCasts(callExpr(callee(functionDecl(
606       hasBody(has(returnStmt(hasReturnValue(AnyOfCallOrDREWithoutInc))))))));
607 
608   // - Example:  int length = getLength(src);
609   auto DREHasReturnWithoutInc = ignoringImpCasts(
610       declRefExpr(to(varDecl(hasInitializer(CallExprReturnWithoutInc)))));
611 
612   auto AnyOfWrongLengthInit =
613       anyOf(WrongLength, AnyOfCallOrDREWithoutInc, CallExprReturnWithoutInc,
614             DREHasReturnWithoutInc);
615 
616   //===--------------------------------------------------------------------===//
617   // The following five cases match the 'destination' array length's
618   // expression which is used in 'memcpy()' and 'memmove()' matchers.
619   //===--------------------------------------------------------------------===//
620 
621   // Note: Sometimes the size of char is explicitly written out.
622   auto SizeExpr = anyOf(SizeOfCharExpr, integerLiteral(equals(1)));
623 
624   auto MallocLengthExpr = allOf(
625       callee(functionDecl(
626           hasAnyName("::alloca", "::calloc", "malloc", "realloc"))),
627       hasAnyArgument(allOf(unless(SizeExpr), expr().bind(DestMallocExprName))));
628 
629   // - Example:  (char *)malloc(length);
630   auto DestMalloc = anyOf(callExpr(MallocLengthExpr),
631                           hasDescendant(callExpr(MallocLengthExpr)));
632 
633   // - Example:  new char[length];
634   auto DestCXXNewExpr = ignoringImpCasts(
635       cxxNewExpr(hasArraySize(expr().bind(DestMallocExprName))));
636 
637   auto AnyOfDestInit = anyOf(DestMalloc, DestCXXNewExpr);
638 
639   // - Example:  char dest[13];  or  char dest[length];
640   auto DestArrayTyDecl = declRefExpr(
641       to(anyOf(varDecl(CharTyArray).bind(DestVarDeclName),
642                varDecl(hasInitializer(AnyOfDestInit)).bind(DestVarDeclName))));
643 
644   // - Example:  foo[bar[baz]].qux; (or just ParmVarDecl)
645   auto DestUnknownDecl =
646       declRefExpr(allOf(to(varDecl(AnyOfCharTy).bind(DestVarDeclName)),
647                         expr().bind(UnknownDestName)))
648           .bind(DestExprName);
649 
650   auto AnyOfDestDecl = ignoringImpCasts(
651       anyOf(allOf(hasDefinition(anyOf(AnyOfDestInit, DestArrayTyDecl,
652                                       hasDescendant(DestArrayTyDecl))),
653                   expr().bind(DestExprName)),
654             anyOf(DestUnknownDecl, hasDescendant(DestUnknownDecl))));
655 
656   auto NullTerminatorExpr = binaryOperator(
657       hasLHS(anyOf(hasDescendant(declRefExpr(to(varDecl(
658                        equalsBoundNode(std::string(DestVarDeclName)))))),
659                    hasDescendant(declRefExpr(
660                        equalsBoundNode(std::string(DestExprName)))))),
661       hasRHS(ignoringImpCasts(
662           anyOf(characterLiteral(equals(0U)), integerLiteral(equals(0))))));
663 
664   auto SrcDecl = declRefExpr(
665       allOf(to(decl().bind(SrcVarDeclName)),
666             anyOf(hasAncestor(cxxMemberCallExpr().bind(SrcExprName)),
667                   expr().bind(SrcExprName))));
668 
669   auto AnyOfSrcDecl =
670       ignoringImpCasts(anyOf(stringLiteral().bind(SrcExprName),
671                              hasDescendant(stringLiteral().bind(SrcExprName)),
672                              SrcDecl, hasDescendant(SrcDecl)));
673 
674   //===--------------------------------------------------------------------===//
675   // Match the problematic function calls.
676   //===--------------------------------------------------------------------===//
677 
678   struct CallContext {
679     CallContext(StringRef Name, Optional<unsigned> DestinationPos,
680                 Optional<unsigned> SourcePos, unsigned LengthPos,
681                 bool WithIncrease)
682         : Name(Name), DestinationPos(DestinationPos), SourcePos(SourcePos),
683           LengthPos(LengthPos), WithIncrease(WithIncrease){};
684 
685     StringRef Name;
686     Optional<unsigned> DestinationPos;
687     Optional<unsigned> SourcePos;
688     unsigned LengthPos;
689     bool WithIncrease;
690   };
691 
692   auto MatchDestination = [=](CallContext CC) {
693     return hasArgument(*CC.DestinationPos,
694                        allOf(AnyOfDestDecl,
695                              unless(hasAncestor(compoundStmt(
696                                  hasDescendant(NullTerminatorExpr)))),
697                              unless(Container)));
698   };
699 
700   auto MatchSource = [=](CallContext CC) {
701     return hasArgument(*CC.SourcePos, AnyOfSrcDecl);
702   };
703 
704   auto MatchGivenLength = [=](CallContext CC) {
705     return hasArgument(
706         CC.LengthPos,
707         allOf(
708             anyOf(
709                 ignoringImpCasts(integerLiteral().bind(WrongLengthExprName)),
710                 allOf(unless(hasDefinition(SizeOfCharExpr)),
711                       allOf(CC.WithIncrease
712                                 ? ignoringImpCasts(hasDefinition(HasIncOp))
713                                 : ignoringImpCasts(allOf(
714                                       unless(hasDefinition(HasIncOp)),
715                                       anyOf(hasDefinition(binaryOperator().bind(
716                                                 UnknownLengthName)),
717                                             hasDefinition(anything())))),
718                             AnyOfWrongLengthInit))),
719             expr().bind(LengthExprName)));
720   };
721 
722   auto MatchCall = [=](CallContext CC) {
723     std::string CharHandlerFuncName = "::" + CC.Name.str();
724 
725     // Try to match with 'wchar_t' based function calls.
726     std::string WcharHandlerFuncName =
727         "::" + (CC.Name.startswith("mem") ? "w" + CC.Name.str()
728                                           : "wcs" + CC.Name.substr(3).str());
729 
730     return allOf(callee(functionDecl(
731                      hasAnyName(CharHandlerFuncName, WcharHandlerFuncName))),
732                  MatchGivenLength(CC));
733   };
734 
735   auto Match = [=](CallContext CC) {
736     if (CC.DestinationPos && CC.SourcePos)
737       return allOf(MatchCall(CC), MatchDestination(CC), MatchSource(CC));
738 
739     if (CC.DestinationPos && !CC.SourcePos)
740       return allOf(MatchCall(CC), MatchDestination(CC),
741                    hasArgument(*CC.DestinationPos, anything()));
742 
743     if (!CC.DestinationPos && CC.SourcePos)
744       return allOf(MatchCall(CC), MatchSource(CC),
745                    hasArgument(*CC.SourcePos, anything()));
746 
747     llvm_unreachable("Unhandled match");
748   };
749 
750   // void *memcpy(void *dest, const void *src, size_t count)
751   auto Memcpy = Match({"memcpy", 0, 1, 2, false});
752 
753   // errno_t memcpy_s(void *dest, size_t ds, const void *src, size_t count)
754   auto MemcpyS = Match({"memcpy_s", 0, 2, 3, false});
755 
756   // void *memchr(const void *src, int c, size_t count)
757   auto Memchr = Match({"memchr", std::nullopt, 0, 2, false});
758 
759   // void *memmove(void *dest, const void *src, size_t count)
760   auto Memmove = Match({"memmove", 0, 1, 2, false});
761 
762   // errno_t memmove_s(void *dest, size_t ds, const void *src, size_t count)
763   auto MemmoveS = Match({"memmove_s", 0, 2, 3, false});
764 
765   // int strncmp(const char *str1, const char *str2, size_t count);
766   auto StrncmpRHS = Match({"strncmp", std::nullopt, 1, 2, true});
767   auto StrncmpLHS = Match({"strncmp", std::nullopt, 0, 2, true});
768 
769   // size_t strxfrm(char *dest, const char *src, size_t count);
770   auto Strxfrm = Match({"strxfrm", 0, 1, 2, false});
771 
772   // errno_t strerror_s(char *buffer, size_t bufferSize, int errnum);
773   auto StrerrorS = Match({"strerror_s", 0, std::nullopt, 1, false});
774 
775   auto AnyOfMatchers = anyOf(Memcpy, MemcpyS, Memmove, MemmoveS, StrncmpRHS,
776                              StrncmpLHS, Strxfrm, StrerrorS);
777 
778   Finder->addMatcher(callExpr(AnyOfMatchers).bind(FunctionExprName), this);
779 
780   // Need to remove the CastExpr from 'memchr()' as 'strchr()' returns 'char *'.
781   Finder->addMatcher(
782       callExpr(Memchr,
783                unless(hasAncestor(castExpr(unless(implicitCastExpr())))))
784           .bind(FunctionExprName),
785       this);
786   Finder->addMatcher(
787       castExpr(allOf(unless(implicitCastExpr()),
788                      has(callExpr(Memchr).bind(FunctionExprName))))
789           .bind(CastExprName),
790       this);
791 }
792 
793 void NotNullTerminatedResultCheck::check(
794     const MatchFinder::MatchResult &Result) {
795   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
796   if (FunctionExpr->getBeginLoc().isMacroID())
797     return;
798 
799   if (WantToUseSafeFunctions && PP->isMacroDefined("__STDC_LIB_EXT1__")) {
800     Optional<bool> AreSafeFunctionsWanted;
801 
802     Preprocessor::macro_iterator It = PP->macro_begin();
803     while (It != PP->macro_end() && !AreSafeFunctionsWanted) {
804       if (It->first->getName() == "__STDC_WANT_LIB_EXT1__") {
805         const auto *MI = PP->getMacroInfo(It->first);
806         // PP->getMacroInfo() returns nullptr if macro has no definition.
807         if (MI) {
808           const auto &T = MI->tokens().back();
809           if (T.isLiteral() && T.getLiteralData()) {
810             StringRef ValueStr = StringRef(T.getLiteralData(), T.getLength());
811             llvm::APInt IntValue;
812             ValueStr.getAsInteger(10, IntValue);
813             AreSafeFunctionsWanted = IntValue.getZExtValue();
814           }
815         }
816       }
817 
818       ++It;
819     }
820 
821     if (AreSafeFunctionsWanted)
822       UseSafeFunctions = *AreSafeFunctionsWanted;
823   }
824 
825   StringRef Name = FunctionExpr->getDirectCallee()->getName();
826   if (Name.startswith("mem") || Name.startswith("wmem"))
827     memoryHandlerFunctionFix(Name, Result);
828   else if (Name == "strerror_s")
829     strerror_sFix(Result);
830   else if (Name.endswith("ncmp"))
831     ncmpFix(Name, Result);
832   else if (Name.endswith("xfrm"))
833     xfrmFix(Name, Result);
834 }
835 
836 void NotNullTerminatedResultCheck::memoryHandlerFunctionFix(
837     StringRef Name, const MatchFinder::MatchResult &Result) {
838   if (isCorrectGivenLength(Result))
839     return;
840 
841   if (Name.endswith("chr")) {
842     memchrFix(Name, Result);
843     return;
844   }
845 
846   if ((Name.contains("cpy") || Name.contains("move")) &&
847       (isDestAndSrcEquals(Result) || isFixedGivenLengthAndUnknownSrc(Result)))
848     return;
849 
850   auto Diag =
851       diag(Result.Nodes.getNodeAs<CallExpr>(FunctionExprName)->getBeginLoc(),
852            "the result from calling '%0' is not null-terminated")
853       << Name;
854 
855   if (Name.endswith("cpy")) {
856     memcpyFix(Name, Result, Diag);
857   } else if (Name.endswith("cpy_s")) {
858     memcpy_sFix(Name, Result, Diag);
859   } else if (Name.endswith("move")) {
860     memmoveFix(Name, Result, Diag);
861   } else if (Name.endswith("move_s")) {
862     isDestCapacityFix(Result, Diag);
863     lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
864   }
865 }
866 
867 void NotNullTerminatedResultCheck::memcpyFix(
868     StringRef Name, const MatchFinder::MatchResult &Result,
869     DiagnosticBuilder &Diag) {
870   bool IsOverflows = isDestCapacityFix(Result, Diag);
871   bool IsDestFixed = isDestExprFix(Result, Diag);
872 
873   bool IsCopy =
874       isGivenLengthEqualToSrcLength(Result) || isDestBasedOnGivenLength(Result);
875 
876   bool IsSafe = UseSafeFunctions && IsOverflows && isKnownDest(Result) &&
877                 !isDestBasedOnGivenLength(Result);
878 
879   bool IsDestLengthNotRequired =
880       IsSafe && getLangOpts().CPlusPlus &&
881       Result.Nodes.getNodeAs<ArrayType>(DestArrayTyName) && !IsDestFixed;
882 
883   renameMemcpy(Name, IsCopy, IsSafe, Result, Diag);
884 
885   if (IsSafe && !IsDestLengthNotRequired)
886     insertDestCapacityArg(IsOverflows, Name, Result, Diag);
887 
888   if (IsCopy)
889     removeArg(2, Result, Diag);
890 
891   if (!IsCopy && !IsSafe)
892     insertNullTerminatorExpr(Name, Result, Diag);
893 }
894 
895 void NotNullTerminatedResultCheck::memcpy_sFix(
896     StringRef Name, const MatchFinder::MatchResult &Result,
897     DiagnosticBuilder &Diag) {
898   bool IsOverflows = isDestCapacityFix(Result, Diag);
899   bool IsDestFixed = isDestExprFix(Result, Diag);
900 
901   bool RemoveDestLength = getLangOpts().CPlusPlus &&
902                           Result.Nodes.getNodeAs<ArrayType>(DestArrayTyName) &&
903                           !IsDestFixed;
904   bool IsCopy = isGivenLengthEqualToSrcLength(Result);
905   bool IsSafe = IsOverflows;
906 
907   renameMemcpy(Name, IsCopy, IsSafe, Result, Diag);
908 
909   if (!IsSafe || (IsSafe && RemoveDestLength))
910     removeArg(1, Result, Diag);
911   else if (IsOverflows && isKnownDest(Result))
912     lengthArgPosHandle(1, LengthHandleKind::Increase, Result, Diag);
913 
914   if (IsCopy)
915     removeArg(3, Result, Diag);
916 
917   if (!IsCopy && !IsSafe)
918     insertNullTerminatorExpr(Name, Result, Diag);
919 }
920 
921 void NotNullTerminatedResultCheck::memchrFix(
922     StringRef Name, const MatchFinder::MatchResult &Result) {
923   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
924   if (const auto *GivenCL = dyn_cast<CharacterLiteral>(FunctionExpr->getArg(1)))
925     if (GivenCL->getValue() != 0)
926       return;
927 
928   auto Diag = diag(FunctionExpr->getArg(2)->IgnoreParenCasts()->getBeginLoc(),
929                    "the length is too short to include the null terminator");
930 
931   if (const auto *CastExpr = Result.Nodes.getNodeAs<Expr>(CastExprName)) {
932     const auto CastRemoveFix = FixItHint::CreateRemoval(
933         SourceRange(CastExpr->getBeginLoc(),
934                     FunctionExpr->getBeginLoc().getLocWithOffset(-1)));
935     Diag << CastRemoveFix;
936   }
937 
938   StringRef NewFuncName = (Name[0] != 'w') ? "strchr" : "wcschr";
939   renameFunc(NewFuncName, Result, Diag);
940   removeArg(2, Result, Diag);
941 }
942 
943 void NotNullTerminatedResultCheck::memmoveFix(
944     StringRef Name, const MatchFinder::MatchResult &Result,
945     DiagnosticBuilder &Diag) {
946   bool IsOverflows = isDestCapacityFix(Result, Diag);
947 
948   if (UseSafeFunctions && isKnownDest(Result)) {
949     renameFunc((Name[0] != 'w') ? "memmove_s" : "wmemmove_s", Result, Diag);
950     insertDestCapacityArg(IsOverflows, Name, Result, Diag);
951   }
952 
953   lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
954 }
955 
956 void NotNullTerminatedResultCheck::strerror_sFix(
957     const MatchFinder::MatchResult &Result) {
958   auto Diag =
959       diag(Result.Nodes.getNodeAs<CallExpr>(FunctionExprName)->getBeginLoc(),
960            "the result from calling 'strerror_s' is not null-terminated and "
961            "missing the last character of the error message");
962 
963   isDestCapacityFix(Result, Diag);
964   lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
965 }
966 
967 void NotNullTerminatedResultCheck::ncmpFix(
968     StringRef Name, const MatchFinder::MatchResult &Result) {
969   const auto *FunctionExpr = Result.Nodes.getNodeAs<CallExpr>(FunctionExprName);
970   const Expr *FirstArgExpr = FunctionExpr->getArg(0)->IgnoreImpCasts();
971   const Expr *SecondArgExpr = FunctionExpr->getArg(1)->IgnoreImpCasts();
972   bool IsLengthTooLong = false;
973 
974   if (const CallExpr *StrlenExpr = getStrlenExpr(Result)) {
975     const Expr *LengthExprArg = StrlenExpr->getArg(0);
976     StringRef FirstExprStr = exprToStr(FirstArgExpr, Result).trim();
977     StringRef SecondExprStr = exprToStr(SecondArgExpr, Result).trim();
978     StringRef LengthArgStr = exprToStr(LengthExprArg, Result).trim();
979     IsLengthTooLong =
980         LengthArgStr == FirstExprStr || LengthArgStr == SecondExprStr;
981   } else {
982     int SrcLength =
983         getLength(Result.Nodes.getNodeAs<Expr>(SrcExprName), Result);
984     int GivenLength = getGivenLength(Result);
985     if (SrcLength != 0 && GivenLength != 0)
986       IsLengthTooLong = GivenLength > SrcLength;
987   }
988 
989   if (!IsLengthTooLong && !isStringDataAndLength(Result))
990     return;
991 
992   auto Diag = diag(FunctionExpr->getArg(2)->IgnoreParenCasts()->getBeginLoc(),
993                    "comparison length is too long and might lead to a "
994                    "buffer overflow");
995 
996   lengthArgHandle(LengthHandleKind::Decrease, Result, Diag);
997 }
998 
999 void NotNullTerminatedResultCheck::xfrmFix(
1000     StringRef Name, const MatchFinder::MatchResult &Result) {
1001   if (!isDestCapacityOverflows(Result))
1002     return;
1003 
1004   auto Diag =
1005       diag(Result.Nodes.getNodeAs<CallExpr>(FunctionExprName)->getBeginLoc(),
1006            "the result from calling '%0' is not null-terminated")
1007       << Name;
1008 
1009   isDestCapacityFix(Result, Diag);
1010   lengthArgHandle(LengthHandleKind::Increase, Result, Diag);
1011 }
1012 
1013 } // namespace bugprone
1014 } // namespace tidy
1015 } // namespace clang
1016