xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/UseStdNumbersCheck.cpp (revision e1fa2fea03ff94627008054267a244744d76b5c2)
//===--- UseStdNumbersCheck.cpp - clang-tidy ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UseStdNumbersCheck.h"
10 #include "../ClangTidyDiagnosticConsumer.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/Decl.h"
13 #include "clang/AST/Expr.h"
14 #include "clang/AST/Stmt.h"
15 #include "clang/AST/Type.h"
16 #include "clang/ASTMatchers/ASTMatchFinder.h"
17 #include "clang/ASTMatchers/ASTMatchers.h"
18 #include "clang/ASTMatchers/ASTMatchersInternal.h"
19 #include "clang/ASTMatchers/ASTMatchersMacros.h"
20 #include "clang/Basic/Diagnostic.h"
21 #include "clang/Basic/LLVM.h"
22 #include "clang/Basic/LangOptions.h"
23 #include "clang/Basic/SourceLocation.h"
24 #include "clang/Basic/SourceManager.h"
25 #include "clang/Lex/Lexer.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/MathExtras.h"
31 #include <array>
32 #include <cstdint>
33 #include <cstdlib>
34 #include <initializer_list>
35 #include <string>
36 #include <tuple>
37 #include <utility>
38 
39 namespace {
40 using namespace clang::ast_matchers;
41 using clang::ast_matchers::internal::Matcher;
42 using llvm::StringRef;
43 
44 AST_MATCHER_P2(clang::FloatingLiteral, near, double, Value, double,
45                DiffThreshold) {
46   return std::abs(Node.getValueAsApproximateDouble() - Value) < DiffThreshold;
47 }
48 
49 AST_MATCHER_P(clang::QualType, hasCanonicalTypeUnqualified,
50               Matcher<clang::QualType>, InnerMatcher) {
51   return !Node.isNull() &&
52          InnerMatcher.matches(Node->getCanonicalTypeUnqualified(), Finder,
53                               Builder);
54 }
55 
56 AST_MATCHER(clang::QualType, isArithmetic) {
57   return !Node.isNull() && Node->isArithmeticType();
58 }
59 AST_MATCHER(clang::QualType, isFloating) {
60   return !Node.isNull() && Node->isFloatingType();
61 }
62 
63 AST_MATCHER_P(clang::Expr, anyOfExhaustive, std::vector<Matcher<clang::Stmt>>,
64               Exprs) {
65   bool FoundMatch = false;
66   for (const auto &InnerMatcher : Exprs) {
67     clang::ast_matchers::internal::BoundNodesTreeBuilder Result = *Builder;
68     if (InnerMatcher.matches(Node, Finder, &Result)) {
69       *Builder = std::move(Result);
70       FoundMatch = true;
71     }
72   }
73   return FoundMatch;
74 }
75 
76 // Using this struct to store the 'DiffThreshold' config value to create the
77 // matchers without the need to pass 'DiffThreshold' into every matcher.
78 // 'DiffThreshold' is needed in the 'near' matcher, which is used for matching
79 // the literal of every constant and for formulas' subexpressions that look at
80 // literals.
81 struct MatchBuilder {
82   auto
83   ignoreParenAndArithmeticCasting(const Matcher<clang::Expr> Matcher) const {
84     return expr(hasType(qualType(isArithmetic())), ignoringParenCasts(Matcher));
85   }
86 
87   auto ignoreParenAndFloatingCasting(const Matcher<clang::Expr> Matcher) const {
88     return expr(hasType(qualType(isFloating())), ignoringParenCasts(Matcher));
89   }
90 
91   auto matchMathCall(const StringRef FunctionName,
92                      const Matcher<clang::Expr> ArgumentMatcher) const {
93     return expr(ignoreParenAndFloatingCasting(
94         callExpr(callee(functionDecl(hasName(FunctionName),
95                                      hasParameter(0, hasType(isArithmetic())))),
96                  hasArgument(0, ArgumentMatcher))));
97   }
98 
99   auto matchSqrt(const Matcher<clang::Expr> ArgumentMatcher) const {
100     return matchMathCall("sqrt", ArgumentMatcher);
101   }
102 
103   // Used for top-level matchers (i.e. the match that replaces Val with its
104   // constant).
105   //
106   // E.g. The matcher of `std::numbers::pi` uses this matcher to look for
107   // floatLiterals that have the value of pi.
108   //
109   // If the match is for a top-level match, we only care about the literal.
110   auto matchFloatLiteralNear(const StringRef Constant, const double Val) const {
111     return expr(ignoreParenAndFloatingCasting(
112         floatLiteral(near(Val, DiffThreshold)).bind(Constant)));
113   }
114 
115   // Used for non-top-level matchers (i.e. matchers that are used as inner
116   // matchers for top-level matchers).
117   //
118   // E.g.: The matcher of `std::numbers::log2e` uses this matcher to check if
119   // `e` of `log2(e)` is declared constant and initialized with the value for
120   // eulers number.
121   //
122   // Here, we do care about literals and about DeclRefExprs to variable
123   // declarations that are constant and initialized with `Val`. This allows
124   // top-level matchers to see through declared constants for their inner
125   // matches like the `std::numbers::log2e` matcher.
126   auto matchFloatValueNear(const double Val) const {
127     const auto Float = floatLiteral(near(Val, DiffThreshold));
128 
129     const auto Dref = declRefExpr(
130         to(varDecl(hasType(qualType(isConstQualified(), isFloating())),
131                    hasInitializer(ignoreParenAndFloatingCasting(Float)))));
132     return expr(ignoreParenAndFloatingCasting(anyOf(Float, Dref)));
133   }
134 
135   auto matchValue(const int64_t ValInt) const {
136     const auto Int =
137         expr(ignoreParenAndArithmeticCasting(integerLiteral(equals(ValInt))));
138     const auto Float = expr(ignoreParenAndFloatingCasting(
139         matchFloatValueNear(static_cast<double>(ValInt))));
140     const auto Dref = declRefExpr(to(varDecl(
141         hasType(qualType(isConstQualified(), isArithmetic())),
142         hasInitializer(expr(anyOf(ignoringImplicit(Int),
143                                   ignoreParenAndFloatingCasting(Float)))))));
144     return expr(anyOf(Int, Float, Dref));
145   }
146 
147   auto match1Div(const Matcher<clang::Expr> Match) const {
148     return binaryOperator(hasOperatorName("/"), hasLHS(matchValue(1)),
149                           hasRHS(Match));
150   }
151 
152   auto matchEuler() const {
153     return expr(anyOf(matchFloatValueNear(llvm::numbers::e),
154                       matchMathCall("exp", matchValue(1))));
155   }
156   auto matchEulerTopLevel() const {
157     return expr(anyOf(matchFloatLiteralNear("e_literal", llvm::numbers::e),
158                       matchMathCall("exp", matchValue(1)).bind("e_pattern")))
159         .bind("e");
160   }
161 
162   auto matchLog2Euler() const {
163     return expr(
164                anyOf(
165                    matchFloatLiteralNear("log2e_literal", llvm::numbers::log2e),
166                    matchMathCall("log2", matchEuler()).bind("log2e_pattern")))
167         .bind("log2e");
168   }
169 
170   auto matchLog10Euler() const {
171     return expr(
172                anyOf(
173                    matchFloatLiteralNear("log10e_literal",
174                                          llvm::numbers::log10e),
175                    matchMathCall("log10", matchEuler()).bind("log10e_pattern")))
176         .bind("log10e");
177   }
178 
179   auto matchPi() const { return matchFloatValueNear(llvm::numbers::pi); }
180   auto matchPiTopLevel() const {
181     return matchFloatLiteralNear("pi_literal", llvm::numbers::pi).bind("pi");
182   }
183 
184   auto matchEgamma() const {
185     return matchFloatLiteralNear("egamma_literal", llvm::numbers::egamma)
186         .bind("egamma");
187   }
188 
189   auto matchInvPi() const {
190     return expr(anyOf(matchFloatLiteralNear("inv_pi_literal",
191                                             llvm::numbers::inv_pi),
192                       match1Div(matchPi()).bind("inv_pi_pattern")))
193         .bind("inv_pi");
194   }
195 
196   auto matchInvSqrtPi() const {
197     return expr(anyOf(
198                     matchFloatLiteralNear("inv_sqrtpi_literal",
199                                           llvm::numbers::inv_sqrtpi),
200                     match1Div(matchSqrt(matchPi())).bind("inv_sqrtpi_pattern")))
201         .bind("inv_sqrtpi");
202   }
203 
204   auto matchLn2() const {
205     return expr(anyOf(matchFloatLiteralNear("ln2_literal", llvm::numbers::ln2),
206                       matchMathCall("log", matchValue(2)).bind("ln2_pattern")))
207         .bind("ln2");
208   }
209 
210   auto machterLn10() const {
211     return expr(
212                anyOf(matchFloatLiteralNear("ln10_literal", llvm::numbers::ln10),
213                      matchMathCall("log", matchValue(10)).bind("ln10_pattern")))
214         .bind("ln10");
215   }
216 
217   auto matchSqrt2() const {
218     return expr(anyOf(matchFloatLiteralNear("sqrt2_literal",
219                                             llvm::numbers::sqrt2),
220                       matchSqrt(matchValue(2)).bind("sqrt2_pattern")))
221         .bind("sqrt2");
222   }
223 
224   auto matchSqrt3() const {
225     return expr(anyOf(matchFloatLiteralNear("sqrt3_literal",
226                                             llvm::numbers::sqrt3),
227                       matchSqrt(matchValue(3)).bind("sqrt3_pattern")))
228         .bind("sqrt3");
229   }
230 
231   auto matchInvSqrt3() const {
232     return expr(anyOf(matchFloatLiteralNear("inv_sqrt3_literal",
233                                             llvm::numbers::inv_sqrt3),
234                       match1Div(matchSqrt(matchValue(3)))
235                           .bind("inv_sqrt3_pattern")))
236         .bind("inv_sqrt3");
237   }
238 
239   auto matchPhi() const {
240     const auto PhiFormula = binaryOperator(
241         hasOperatorName("/"),
242         hasLHS(binaryOperator(
243             hasOperatorName("+"), hasEitherOperand(matchValue(1)),
244             hasEitherOperand(matchMathCall("sqrt", matchValue(5))))),
245         hasRHS(matchValue(2)));
246     return expr(anyOf(PhiFormula.bind("phi_pattern"),
247                       matchFloatLiteralNear("phi_literal", llvm::numbers::phi)))
248         .bind("phi");
249   }
250 
251   double DiffThreshold;
252 };
253 
254 std::string getCode(const StringRef Constant, const bool IsFloat,
255                     const bool IsLongDouble) {
256   if (IsFloat) {
257     return ("std::numbers::" + Constant + "_v<float>").str();
258   }
259   if (IsLongDouble) {
260     return ("std::numbers::" + Constant + "_v<long double>").str();
261   }
262   return ("std::numbers::" + Constant).str();
263 }
264 
265 bool isRangeOfCompleteMacro(const clang::SourceRange &Range,
266                             const clang::SourceManager &SM,
267                             const clang::LangOptions &LO) {
268   if (!Range.getBegin().isMacroID()) {
269     return false;
270   }
271   if (!clang::Lexer::isAtStartOfMacroExpansion(Range.getBegin(), SM, LO)) {
272     return false;
273   }
274 
275   if (!Range.getEnd().isMacroID()) {
276     return false;
277   }
278 
279   if (!clang::Lexer::isAtEndOfMacroExpansion(Range.getEnd(), SM, LO)) {
280     return false;
281   }
282 
283   return true;
284 }
285 
286 } // namespace
287 
288 namespace clang::tidy::modernize {
289 UseStdNumbersCheck::UseStdNumbersCheck(const StringRef Name,
290                                        ClangTidyContext *const Context)
291     : ClangTidyCheck(Name, Context),
292       IncludeInserter(Options.getLocalOrGlobal("IncludeStyle",
293                                                utils::IncludeSorter::IS_LLVM),
294                       areDiagsSelfContained()),
295       DiffThresholdString{Options.get("DiffThreshold", "0.001")} {
296   if (DiffThresholdString.getAsDouble(DiffThreshold)) {
297     configurationDiag(
298         "Invalid DiffThreshold config value: '%0', expected a double")
299         << DiffThresholdString;
300     DiffThreshold = 0.001;
301   }
302 }
303 
304 void UseStdNumbersCheck::registerMatchers(MatchFinder *const Finder) {
305   const auto Matches = MatchBuilder{DiffThreshold};
306   std::vector<Matcher<clang::Stmt>> ConstantMatchers = {
307       Matches.matchLog2Euler(),     Matches.matchLog10Euler(),
308       Matches.matchEulerTopLevel(), Matches.matchEgamma(),
309       Matches.matchInvSqrtPi(),     Matches.matchInvPi(),
310       Matches.matchPiTopLevel(),    Matches.matchLn2(),
311       Matches.machterLn10(),        Matches.matchSqrt2(),
312       Matches.matchInvSqrt3(),      Matches.matchSqrt3(),
313       Matches.matchPhi(),
314   };
315 
316   Finder->addMatcher(
317       expr(
318           anyOfExhaustive(std::move(ConstantMatchers)),
319           unless(hasParent(explicitCastExpr(hasDestinationType(isFloating())))),
320           hasType(qualType(hasCanonicalTypeUnqualified(
321               anyOf(qualType(asString("float")).bind("float"),
322                     qualType(asString("double")),
323                     qualType(asString("long double")).bind("long double")))))),
324       this);
325 }
326 
327 void UseStdNumbersCheck::check(const MatchFinder::MatchResult &Result) {
328   /*
329     List of all math constants in the `<numbers>` header
330     + e
331     + log2e
332     + log10e
333     + pi
334     + inv_pi
335     + inv_sqrtpi
336     + ln2
337     + ln10
338     + sqrt2
339     + sqrt3
340     + inv_sqrt3
341     + egamma
342     + phi
343   */
344 
345   // The ordering determines what constants are looked at first.
346   // E.g. look at 'inv_sqrt3' before 'sqrt3' to be able to replace the larger
347   // expression
348   constexpr auto Constants = std::array<std::pair<StringRef, double>, 13>{
349       std::pair{StringRef{"log2e"}, llvm::numbers::log2e},
350       std::pair{StringRef{"log10e"}, llvm::numbers::log10e},
351       std::pair{StringRef{"e"}, llvm::numbers::e},
352       std::pair{StringRef{"egamma"}, llvm::numbers::egamma},
353       std::pair{StringRef{"inv_sqrtpi"}, llvm::numbers::inv_sqrtpi},
354       std::pair{StringRef{"inv_pi"}, llvm::numbers::inv_pi},
355       std::pair{StringRef{"pi"}, llvm::numbers::pi},
356       std::pair{StringRef{"ln2"}, llvm::numbers::ln2},
357       std::pair{StringRef{"ln10"}, llvm::numbers::ln10},
358       std::pair{StringRef{"sqrt2"}, llvm::numbers::sqrt2},
359       std::pair{StringRef{"inv_sqrt3"}, llvm::numbers::inv_sqrt3},
360       std::pair{StringRef{"sqrt3"}, llvm::numbers::sqrt3},
361       std::pair{StringRef{"phi"}, llvm::numbers::phi},
362   };
363 
364   auto MatchedLiterals =
365       llvm::SmallVector<std::tuple<std::string, double, const Expr *>>{};
366 
367   const auto &SM = *Result.SourceManager;
368   const auto &LO = Result.Context->getLangOpts();
369 
370   const auto IsFloat = Result.Nodes.getNodeAs<QualType>("float") != nullptr;
371   const auto IsLongDouble =
372       Result.Nodes.getNodeAs<QualType>("long double") != nullptr;
373 
374   for (const auto &[ConstantName, ConstantValue] : Constants) {
375     const auto *const Match = Result.Nodes.getNodeAs<Expr>(ConstantName);
376     if (Match == nullptr) {
377       continue;
378     }
379 
380     const auto Range = Match->getSourceRange();
381 
382     const auto IsMacro = Range.getBegin().isMacroID();
383 
384     // We do not want to emit a diagnostic when we are matching a macro, but the
385     // match inside of the macro does not cover the whole macro.
386     if (IsMacro && !isRangeOfCompleteMacro(Range, SM, LO)) {
387       continue;
388     }
389 
390     if (const auto PatternBindString = (ConstantName + "_pattern").str();
391         Result.Nodes.getNodeAs<Expr>(PatternBindString) != nullptr) {
392       const auto Code = getCode(ConstantName, IsFloat, IsLongDouble);
393       diag(Range.getBegin(), "prefer '%0' to this %select{formula|macro}1")
394           << Code << IsMacro << FixItHint::CreateReplacement(Range, Code);
395       return;
396     }
397 
398     const auto LiteralBindString = (ConstantName + "_literal").str();
399     if (const auto *const Literal =
400             Result.Nodes.getNodeAs<FloatingLiteral>(LiteralBindString)) {
401       MatchedLiterals.emplace_back(
402           ConstantName,
403           std::abs(Literal->getValueAsApproximateDouble() - ConstantValue),
404           Match);
405     }
406   }
407 
408   // We may have had no matches with literals, but a match with a pattern that
409   // was a part of a macro which was therefore skipped.
410   if (MatchedLiterals.empty()) {
411     return;
412   }
413 
414   llvm::sort(MatchedLiterals, [](const auto &LHS, const auto &RHS) {
415     return std::get<1>(LHS) < std::get<1>(RHS);
416   });
417 
418   const auto &[Constant, Diff, Node] = MatchedLiterals.front();
419 
420   const auto Range = Node->getSourceRange();
421   const auto IsMacro = Range.getBegin().isMacroID();
422 
423   // We do not want to emit a diagnostic when we are matching a macro, but the
424   // match inside of the macro does not cover the whole macro.
425   if (IsMacro && !isRangeOfCompleteMacro(Range, SM, LO)) {
426     return;
427   }
428 
429   const auto Code = getCode(Constant, IsFloat, IsLongDouble);
430   diag(Range.getBegin(),
431        "prefer '%0' to this %select{literal|macro}1, differs by '%2'")
432       << Code << IsMacro << llvm::formatv("{0:e2}", Diff).str()
433       << FixItHint::CreateReplacement(Range, Code)
434       << IncludeInserter.createIncludeInsertion(
435              Result.SourceManager->getFileID(Range.getBegin()), "<numbers>");
436 }
437 
438 void UseStdNumbersCheck::registerPPCallbacks(
439     const SourceManager &SM, Preprocessor *const PP,
440     Preprocessor *const ModuleExpanderPP) {
441   IncludeInserter.registerPreprocessor(PP);
442 }
443 
444 void UseStdNumbersCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
445   Options.store(Opts, "IncludeStyle", IncludeInserter.getStyle());
446   Options.store(Opts, "DiffThreshold", DiffThresholdString);
447 }
448 } // namespace clang::tidy::modernize
449