xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseStdNumbersCheck.cpp (revision a5f54175dcf120180c3d91bbc13062bbf8f42f61)
//===--- UseStdNumbersCheck.cpp - clang-tidy ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
#include "UseStdNumbersCheck.h"
#include "../ClangTidyDiagnosticConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/Type.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersInternal.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Basic/Diagnostic.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Lex/Lexer.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <string>
#include <tuple>
#include <utility>
39 
40 namespace {
41 using namespace clang::ast_matchers;
42 using clang::ast_matchers::internal::Matcher;
43 using llvm::StringRef;
44 
// Matches a floating literal whose approximate value lies strictly within
// 'DiffThreshold' of 'Value'.
AST_MATCHER_P2(clang::FloatingLiteral, near, double, Value, double,
               DiffThreshold) {
  const double Distance =
      std::abs(Value - Node.getValueAsApproximateDouble());
  return Distance < DiffThreshold;
}
49 
// Matches a QualType whose canonical, unqualified form satisfies
// 'InnerMatcher'. A null QualType never matches.
AST_MATCHER_P(clang::QualType, hasCanonicalTypeUnqualified,
              Matcher<clang::QualType>, InnerMatcher) {
  if (Node.isNull())
    return false;
  const clang::QualType Canonical = Node->getCanonicalTypeUnqualified();
  return InnerMatcher.matches(Canonical, Finder, Builder);
}
56 
// Matches a non-null QualType that clang classifies as arithmetic
// (Type::isArithmeticType).
AST_MATCHER(clang::QualType, isArithmetic) {
  if (Node.isNull())
    return false;
  return Node->isArithmeticType();
}
// Matches a non-null QualType that clang classifies as floating
// (Type::isFloatingType).
AST_MATCHER(clang::QualType, isFloating) {
  if (Node.isNull())
    return false;
  return Node->isFloatingType();
}
63 
// Like 'anyOf', but does not stop at the first successful inner matcher:
// every matcher in 'Exprs' is tried against 'Node'. The bindings of each
// successful matcher are committed to 'Builder', so later matchers start from
// (and extend) the bindings of earlier successful ones. Returns true if at
// least one inner matcher matched.
AST_MATCHER_P(clang::Expr, anyOfExhaustive, std::vector<Matcher<clang::Stmt>>,
              Exprs) {
  bool AnyMatched = false;
  for (const Matcher<clang::Stmt> &Candidate : Exprs) {
    // Work on a scratch copy so a failed match cannot leave partial bindings.
    auto Scratch = *Builder;
    if (!Candidate.matches(Node, Finder, &Scratch))
      continue;
    *Builder = std::move(Scratch);
    AnyMatched = true;
  }
  return AnyMatched;
}
76 
77 // Using this struct to store the 'DiffThreshold' config value to create the
78 // matchers without the need to pass 'DiffThreshold' into every matcher.
79 // 'DiffThreshold' is needed in the 'near' matcher, which is used for matching
80 // the literal of every constant and for formulas' subexpressions that look at
81 // literals.
82 struct MatchBuilder {
83   auto
ignoreParenAndArithmeticCasting__anon60bcb47f0111::MatchBuilder84   ignoreParenAndArithmeticCasting(const Matcher<clang::Expr> Matcher) const {
85     return expr(hasType(qualType(isArithmetic())), ignoringParenCasts(Matcher));
86   }
87 
ignoreParenAndFloatingCasting__anon60bcb47f0111::MatchBuilder88   auto ignoreParenAndFloatingCasting(const Matcher<clang::Expr> Matcher) const {
89     return expr(hasType(qualType(isFloating())), ignoringParenCasts(Matcher));
90   }
91 
matchMathCall__anon60bcb47f0111::MatchBuilder92   auto matchMathCall(const StringRef FunctionName,
93                      const Matcher<clang::Expr> ArgumentMatcher) const {
94     return expr(ignoreParenAndFloatingCasting(
95         callExpr(callee(functionDecl(hasName(FunctionName),
96                                      hasParameter(0, hasType(isArithmetic())))),
97                  hasArgument(0, ArgumentMatcher))));
98   }
99 
matchSqrt__anon60bcb47f0111::MatchBuilder100   auto matchSqrt(const Matcher<clang::Expr> ArgumentMatcher) const {
101     return matchMathCall("sqrt", ArgumentMatcher);
102   }
103 
104   // Used for top-level matchers (i.e. the match that replaces Val with its
105   // constant).
106   //
107   // E.g. The matcher of `std::numbers::pi` uses this matcher to look for
108   // floatLiterals that have the value of pi.
109   //
110   // If the match is for a top-level match, we only care about the literal.
matchFloatLiteralNear__anon60bcb47f0111::MatchBuilder111   auto matchFloatLiteralNear(const StringRef Constant, const double Val) const {
112     return expr(ignoreParenAndFloatingCasting(
113         floatLiteral(near(Val, DiffThreshold)).bind(Constant)));
114   }
115 
116   // Used for non-top-level matchers (i.e. matchers that are used as inner
117   // matchers for top-level matchers).
118   //
119   // E.g.: The matcher of `std::numbers::log2e` uses this matcher to check if
120   // `e` of `log2(e)` is declared constant and initialized with the value for
121   // eulers number.
122   //
123   // Here, we do care about literals and about DeclRefExprs to variable
124   // declarations that are constant and initialized with `Val`. This allows
125   // top-level matchers to see through declared constants for their inner
126   // matches like the `std::numbers::log2e` matcher.
matchFloatValueNear__anon60bcb47f0111::MatchBuilder127   auto matchFloatValueNear(const double Val) const {
128     const auto Float = floatLiteral(near(Val, DiffThreshold));
129 
130     const auto Dref = declRefExpr(
131         to(varDecl(hasType(qualType(isConstQualified(), isFloating())),
132                    hasInitializer(ignoreParenAndFloatingCasting(Float)))));
133     return expr(ignoreParenAndFloatingCasting(anyOf(Float, Dref)));
134   }
135 
matchValue__anon60bcb47f0111::MatchBuilder136   auto matchValue(const int64_t ValInt) const {
137     const auto Int =
138         expr(ignoreParenAndArithmeticCasting(integerLiteral(equals(ValInt))));
139     const auto Float = expr(ignoreParenAndFloatingCasting(
140         matchFloatValueNear(static_cast<double>(ValInt))));
141     const auto Dref = declRefExpr(to(varDecl(
142         hasType(qualType(isConstQualified(), isArithmetic())),
143         hasInitializer(expr(anyOf(ignoringImplicit(Int),
144                                   ignoreParenAndFloatingCasting(Float)))))));
145     return expr(anyOf(Int, Float, Dref));
146   }
147 
match1Div__anon60bcb47f0111::MatchBuilder148   auto match1Div(const Matcher<clang::Expr> Match) const {
149     return binaryOperator(hasOperatorName("/"), hasLHS(matchValue(1)),
150                           hasRHS(Match));
151   }
152 
matchEuler__anon60bcb47f0111::MatchBuilder153   auto matchEuler() const {
154     return expr(anyOf(matchFloatValueNear(llvm::numbers::e),
155                       matchMathCall("exp", matchValue(1))));
156   }
matchEulerTopLevel__anon60bcb47f0111::MatchBuilder157   auto matchEulerTopLevel() const {
158     return expr(anyOf(matchFloatLiteralNear("e_literal", llvm::numbers::e),
159                       matchMathCall("exp", matchValue(1)).bind("e_pattern")))
160         .bind("e");
161   }
162 
matchLog2Euler__anon60bcb47f0111::MatchBuilder163   auto matchLog2Euler() const {
164     return expr(
165                anyOf(
166                    matchFloatLiteralNear("log2e_literal", llvm::numbers::log2e),
167                    matchMathCall("log2", matchEuler()).bind("log2e_pattern")))
168         .bind("log2e");
169   }
170 
matchLog10Euler__anon60bcb47f0111::MatchBuilder171   auto matchLog10Euler() const {
172     return expr(
173                anyOf(
174                    matchFloatLiteralNear("log10e_literal",
175                                          llvm::numbers::log10e),
176                    matchMathCall("log10", matchEuler()).bind("log10e_pattern")))
177         .bind("log10e");
178   }
179 
matchPi__anon60bcb47f0111::MatchBuilder180   auto matchPi() const { return matchFloatValueNear(llvm::numbers::pi); }
matchPiTopLevel__anon60bcb47f0111::MatchBuilder181   auto matchPiTopLevel() const {
182     return matchFloatLiteralNear("pi_literal", llvm::numbers::pi).bind("pi");
183   }
184 
matchEgamma__anon60bcb47f0111::MatchBuilder185   auto matchEgamma() const {
186     return matchFloatLiteralNear("egamma_literal", llvm::numbers::egamma)
187         .bind("egamma");
188   }
189 
matchInvPi__anon60bcb47f0111::MatchBuilder190   auto matchInvPi() const {
191     return expr(anyOf(matchFloatLiteralNear("inv_pi_literal",
192                                             llvm::numbers::inv_pi),
193                       match1Div(matchPi()).bind("inv_pi_pattern")))
194         .bind("inv_pi");
195   }
196 
matchInvSqrtPi__anon60bcb47f0111::MatchBuilder197   auto matchInvSqrtPi() const {
198     return expr(anyOf(
199                     matchFloatLiteralNear("inv_sqrtpi_literal",
200                                           llvm::numbers::inv_sqrtpi),
201                     match1Div(matchSqrt(matchPi())).bind("inv_sqrtpi_pattern")))
202         .bind("inv_sqrtpi");
203   }
204 
matchLn2__anon60bcb47f0111::MatchBuilder205   auto matchLn2() const {
206     return expr(anyOf(matchFloatLiteralNear("ln2_literal", llvm::numbers::ln2),
207                       matchMathCall("log", matchValue(2)).bind("ln2_pattern")))
208         .bind("ln2");
209   }
210 
machterLn10__anon60bcb47f0111::MatchBuilder211   auto machterLn10() const {
212     return expr(
213                anyOf(matchFloatLiteralNear("ln10_literal", llvm::numbers::ln10),
214                      matchMathCall("log", matchValue(10)).bind("ln10_pattern")))
215         .bind("ln10");
216   }
217 
matchSqrt2__anon60bcb47f0111::MatchBuilder218   auto matchSqrt2() const {
219     return expr(anyOf(matchFloatLiteralNear("sqrt2_literal",
220                                             llvm::numbers::sqrt2),
221                       matchSqrt(matchValue(2)).bind("sqrt2_pattern")))
222         .bind("sqrt2");
223   }
224 
matchSqrt3__anon60bcb47f0111::MatchBuilder225   auto matchSqrt3() const {
226     return expr(anyOf(matchFloatLiteralNear("sqrt3_literal",
227                                             llvm::numbers::sqrt3),
228                       matchSqrt(matchValue(3)).bind("sqrt3_pattern")))
229         .bind("sqrt3");
230   }
231 
matchInvSqrt3__anon60bcb47f0111::MatchBuilder232   auto matchInvSqrt3() const {
233     return expr(anyOf(matchFloatLiteralNear("inv_sqrt3_literal",
234                                             llvm::numbers::inv_sqrt3),
235                       match1Div(matchSqrt(matchValue(3)))
236                           .bind("inv_sqrt3_pattern")))
237         .bind("inv_sqrt3");
238   }
239 
matchPhi__anon60bcb47f0111::MatchBuilder240   auto matchPhi() const {
241     const auto PhiFormula = binaryOperator(
242         hasOperatorName("/"),
243         hasLHS(binaryOperator(
244             hasOperatorName("+"), hasEitherOperand(matchValue(1)),
245             hasEitherOperand(matchMathCall("sqrt", matchValue(5))))),
246         hasRHS(matchValue(2)));
247     return expr(anyOf(PhiFormula.bind("phi_pattern"),
248                       matchFloatLiteralNear("phi_literal", llvm::numbers::phi)))
249         .bind("phi");
250   }
251 
252   double DiffThreshold;
253 };
254 
getCode(const StringRef Constant,const bool IsFloat,const bool IsLongDouble)255 std::string getCode(const StringRef Constant, const bool IsFloat,
256                     const bool IsLongDouble) {
257   if (IsFloat) {
258     return ("std::numbers::" + Constant + "_v<float>").str();
259   }
260   if (IsLongDouble) {
261     return ("std::numbers::" + Constant + "_v<long double>").str();
262   }
263   return ("std::numbers::" + Constant).str();
264 }
265 
isRangeOfCompleteMacro(const clang::SourceRange & Range,const clang::SourceManager & SM,const clang::LangOptions & LO)266 bool isRangeOfCompleteMacro(const clang::SourceRange &Range,
267                             const clang::SourceManager &SM,
268                             const clang::LangOptions &LO) {
269   if (!Range.getBegin().isMacroID()) {
270     return false;
271   }
272   if (!clang::Lexer::isAtStartOfMacroExpansion(Range.getBegin(), SM, LO)) {
273     return false;
274   }
275 
276   if (!Range.getEnd().isMacroID()) {
277     return false;
278   }
279 
280   if (!clang::Lexer::isAtEndOfMacroExpansion(Range.getEnd(), SM, LO)) {
281     return false;
282   }
283 
284   return true;
285 }
286 
287 } // namespace
288 
289 namespace clang::tidy::modernize {
UseStdNumbersCheck(const StringRef Name,ClangTidyContext * const Context)290 UseStdNumbersCheck::UseStdNumbersCheck(const StringRef Name,
291                                        ClangTidyContext *const Context)
292     : ClangTidyCheck(Name, Context),
293       IncludeInserter(Options.getLocalOrGlobal("IncludeStyle",
294                                                utils::IncludeSorter::IS_LLVM),
295                       areDiagsSelfContained()),
296       DiffThresholdString{Options.get("DiffThreshold", "0.001")} {
297   if (DiffThresholdString.getAsDouble(DiffThreshold)) {
298     configurationDiag(
299         "Invalid DiffThreshold config value: '%0', expected a double")
300         << DiffThresholdString;
301     DiffThreshold = 0.001;
302   }
303 }
304 
registerMatchers(MatchFinder * const Finder)305 void UseStdNumbersCheck::registerMatchers(MatchFinder *const Finder) {
306   const auto Matches = MatchBuilder{DiffThreshold};
307   std::vector<Matcher<clang::Stmt>> ConstantMatchers = {
308       Matches.matchLog2Euler(),     Matches.matchLog10Euler(),
309       Matches.matchEulerTopLevel(), Matches.matchEgamma(),
310       Matches.matchInvSqrtPi(),     Matches.matchInvPi(),
311       Matches.matchPiTopLevel(),    Matches.matchLn2(),
312       Matches.machterLn10(),        Matches.matchSqrt2(),
313       Matches.matchInvSqrt3(),      Matches.matchSqrt3(),
314       Matches.matchPhi(),
315   };
316 
317   Finder->addMatcher(
318       expr(
319           anyOfExhaustive(std::move(ConstantMatchers)),
320           unless(hasParent(explicitCastExpr(hasDestinationType(isFloating())))),
321           hasType(qualType(hasCanonicalTypeUnqualified(
322               anyOf(qualType(asString("float")).bind("float"),
323                     qualType(asString("double")),
324                     qualType(asString("long double")).bind("long double")))))),
325       this);
326 }
327 
check(const MatchFinder::MatchResult & Result)328 void UseStdNumbersCheck::check(const MatchFinder::MatchResult &Result) {
329   /*
330     List of all math constants in the `<numbers>` header
331     + e
332     + log2e
333     + log10e
334     + pi
335     + inv_pi
336     + inv_sqrtpi
337     + ln2
338     + ln10
339     + sqrt2
340     + sqrt3
341     + inv_sqrt3
342     + egamma
343     + phi
344   */
345 
346   // The ordering determines what constants are looked at first.
347   // E.g. look at 'inv_sqrt3' before 'sqrt3' to be able to replace the larger
348   // expression
349   constexpr auto Constants = std::array<std::pair<StringRef, double>, 13>{
350       std::pair{StringRef{"log2e"}, llvm::numbers::log2e},
351       std::pair{StringRef{"log10e"}, llvm::numbers::log10e},
352       std::pair{StringRef{"e"}, llvm::numbers::e},
353       std::pair{StringRef{"egamma"}, llvm::numbers::egamma},
354       std::pair{StringRef{"inv_sqrtpi"}, llvm::numbers::inv_sqrtpi},
355       std::pair{StringRef{"inv_pi"}, llvm::numbers::inv_pi},
356       std::pair{StringRef{"pi"}, llvm::numbers::pi},
357       std::pair{StringRef{"ln2"}, llvm::numbers::ln2},
358       std::pair{StringRef{"ln10"}, llvm::numbers::ln10},
359       std::pair{StringRef{"sqrt2"}, llvm::numbers::sqrt2},
360       std::pair{StringRef{"inv_sqrt3"}, llvm::numbers::inv_sqrt3},
361       std::pair{StringRef{"sqrt3"}, llvm::numbers::sqrt3},
362       std::pair{StringRef{"phi"}, llvm::numbers::phi},
363   };
364 
365   auto MatchedLiterals =
366       llvm::SmallVector<std::tuple<std::string, double, const Expr *>>{};
367 
368   const auto &SM = *Result.SourceManager;
369   const auto &LO = Result.Context->getLangOpts();
370 
371   const auto IsFloat = Result.Nodes.getNodeAs<QualType>("float") != nullptr;
372   const auto IsLongDouble =
373       Result.Nodes.getNodeAs<QualType>("long double") != nullptr;
374 
375   for (const auto &[ConstantName, ConstantValue] : Constants) {
376     const auto *const Match = Result.Nodes.getNodeAs<Expr>(ConstantName);
377     if (Match == nullptr) {
378       continue;
379     }
380 
381     const auto Range = Match->getSourceRange();
382 
383     const auto IsMacro = Range.getBegin().isMacroID();
384 
385     // We do not want to emit a diagnostic when we are matching a macro, but the
386     // match inside of the macro does not cover the whole macro.
387     if (IsMacro && !isRangeOfCompleteMacro(Range, SM, LO)) {
388       continue;
389     }
390 
391     if (const auto PatternBindString = (ConstantName + "_pattern").str();
392         Result.Nodes.getNodeAs<Expr>(PatternBindString) != nullptr) {
393       const auto Code = getCode(ConstantName, IsFloat, IsLongDouble);
394       diag(Range.getBegin(), "prefer '%0' to this %select{formula|macro}1")
395           << Code << IsMacro << FixItHint::CreateReplacement(Range, Code);
396       return;
397     }
398 
399     const auto LiteralBindString = (ConstantName + "_literal").str();
400     if (const auto *const Literal =
401             Result.Nodes.getNodeAs<FloatingLiteral>(LiteralBindString)) {
402       MatchedLiterals.emplace_back(
403           ConstantName,
404           std::abs(Literal->getValueAsApproximateDouble() - ConstantValue),
405           Match);
406     }
407   }
408 
409   // We may have had no matches with literals, but a match with a pattern that
410   // was a part of a macro which was therefore skipped.
411   if (MatchedLiterals.empty()) {
412     return;
413   }
414 
415   llvm::sort(MatchedLiterals, [](const auto &LHS, const auto &RHS) {
416     return std::get<1>(LHS) < std::get<1>(RHS);
417   });
418 
419   const auto &[Constant, Diff, Node] = MatchedLiterals.front();
420 
421   const auto Range = Node->getSourceRange();
422   const auto IsMacro = Range.getBegin().isMacroID();
423 
424   // We do not want to emit a diagnostic when we are matching a macro, but the
425   // match inside of the macro does not cover the whole macro.
426   if (IsMacro && !isRangeOfCompleteMacro(Range, SM, LO)) {
427     return;
428   }
429 
430   const auto Code = getCode(Constant, IsFloat, IsLongDouble);
431   diag(Range.getBegin(),
432        "prefer '%0' to this %select{literal|macro}1, differs by '%2'")
433       << Code << IsMacro << llvm::formatv("{0:e2}", Diff).str()
434       << FixItHint::CreateReplacement(Range, Code)
435       << IncludeInserter.createIncludeInsertion(
436              Result.SourceManager->getFileID(Range.getBegin()), "<numbers>");
437 }
438 
registerPPCallbacks(const SourceManager & SM,Preprocessor * const PP,Preprocessor * const ModuleExpanderPP)439 void UseStdNumbersCheck::registerPPCallbacks(
440     const SourceManager &SM, Preprocessor *const PP,
441     Preprocessor *const ModuleExpanderPP) {
442   IncludeInserter.registerPreprocessor(PP);
443 }
444 
storeOptions(ClangTidyOptions::OptionMap & Opts)445 void UseStdNumbersCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
446   Options.store(Opts, "IncludeStyle", IncludeInserter.getStyle());
447   Options.store(Opts, "DiffThreshold", DiffThresholdString);
448 }
449 } // namespace clang::tidy::modernize
450