xref: /llvm-project/clang-tools-extra/clang-tidy/altera/UnrollLoopsCheck.cpp (revision 03dff0d4acafb61bbb3b765507de79c5e5b6681a)
1 //===--- UnrollLoopsCheck.cpp - clang-tidy --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UnrollLoopsCheck.h"
10 #include "clang/AST/APValue.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/ASTTypeTraits.h"
13 #include "clang/AST/OperationKinds.h"
14 #include "clang/AST/ParentMapContext.h"
15 #include "clang/ASTMatchers/ASTMatchFinder.h"
16 #include <cmath>
17 
18 using namespace clang::ast_matchers;
19 
20 namespace clang::tidy::altera {
21 
UnrollLoopsCheck(StringRef Name,ClangTidyContext * Context)22 UnrollLoopsCheck::UnrollLoopsCheck(StringRef Name, ClangTidyContext *Context)
23     : ClangTidyCheck(Name, Context),
24       MaxLoopIterations(Options.get("MaxLoopIterations", 100U)) {}
25 
registerMatchers(MatchFinder * Finder)26 void UnrollLoopsCheck::registerMatchers(MatchFinder *Finder) {
27   const auto HasLoopBound = hasDescendant(
28       varDecl(matchesName("__end*"),
29               hasDescendant(integerLiteral().bind("cxx_loop_bound"))));
30   const auto CXXForRangeLoop =
31       cxxForRangeStmt(anyOf(HasLoopBound, unless(HasLoopBound)));
32   const auto AnyLoop = anyOf(forStmt(), whileStmt(), doStmt(), CXXForRangeLoop);
33   Finder->addMatcher(
34       stmt(AnyLoop, unless(hasDescendant(stmt(AnyLoop)))).bind("loop"), this);
35 }
36 
check(const MatchFinder::MatchResult & Result)37 void UnrollLoopsCheck::check(const MatchFinder::MatchResult &Result) {
38   const auto *Loop = Result.Nodes.getNodeAs<Stmt>("loop");
39   const auto *CXXLoopBound =
40       Result.Nodes.getNodeAs<IntegerLiteral>("cxx_loop_bound");
41   const ASTContext *Context = Result.Context;
42   switch (unrollType(Loop, Result.Context)) {
43   case NotUnrolled:
44     diag(Loop->getBeginLoc(),
45          "kernel performance could be improved by unrolling this loop with a "
46          "'#pragma unroll' directive");
47     break;
48   case PartiallyUnrolled:
49     // Loop already partially unrolled, do nothing.
50     break;
51   case FullyUnrolled:
52     if (hasKnownBounds(Loop, CXXLoopBound, Context)) {
53       if (hasLargeNumIterations(Loop, CXXLoopBound, Context)) {
54         diag(Loop->getBeginLoc(),
55              "loop likely has a large number of iterations and thus "
56              "cannot be fully unrolled; to partially unroll this loop, use "
57              "the '#pragma unroll <num>' directive");
58         return;
59       }
60       return;
61     }
62     if (isa<WhileStmt, DoStmt>(Loop)) {
63       diag(Loop->getBeginLoc(),
64            "full unrolling requested, but loop bounds may not be known; to "
65            "partially unroll this loop, use the '#pragma unroll <num>' "
66            "directive",
67            DiagnosticIDs::Note);
68       break;
69     }
70     diag(Loop->getBeginLoc(),
71          "full unrolling requested, but loop bounds are not known; to "
72          "partially unroll this loop, use the '#pragma unroll <num>' "
73          "directive");
74     break;
75   }
76 }
77 
78 enum UnrollLoopsCheck::UnrollType
unrollType(const Stmt * Statement,ASTContext * Context)79 UnrollLoopsCheck::unrollType(const Stmt *Statement, ASTContext *Context) {
80   const DynTypedNodeList Parents = Context->getParents<Stmt>(*Statement);
81   for (const DynTypedNode &Parent : Parents) {
82     const auto *ParentStmt = Parent.get<AttributedStmt>();
83     if (!ParentStmt)
84       continue;
85     for (const Attr *Attribute : ParentStmt->getAttrs()) {
86       const auto *LoopHint = dyn_cast<LoopHintAttr>(Attribute);
87       if (!LoopHint)
88         continue;
89       switch (LoopHint->getState()) {
90       case LoopHintAttr::Numeric:
91         return PartiallyUnrolled;
92       case LoopHintAttr::Disable:
93         return NotUnrolled;
94       case LoopHintAttr::Full:
95         return FullyUnrolled;
96       case LoopHintAttr::Enable:
97         return FullyUnrolled;
98       case LoopHintAttr::AssumeSafety:
99         return NotUnrolled;
100       case LoopHintAttr::FixedWidth:
101         return NotUnrolled;
102       case LoopHintAttr::ScalableWidth:
103         return NotUnrolled;
104       }
105     }
106   }
107   return NotUnrolled;
108 }
109 
hasKnownBounds(const Stmt * Statement,const IntegerLiteral * CXXLoopBound,const ASTContext * Context)110 bool UnrollLoopsCheck::hasKnownBounds(const Stmt *Statement,
111                                       const IntegerLiteral *CXXLoopBound,
112                                       const ASTContext *Context) {
113   if (isa<CXXForRangeStmt>(Statement))
114     return CXXLoopBound != nullptr;
115   // Too many possibilities in a while statement, so always recommend partial
116   // unrolling for these.
117   if (isa<WhileStmt, DoStmt>(Statement))
118     return false;
119   // The last loop type is a for loop.
120   const auto *ForLoop = cast<ForStmt>(Statement);
121   const Stmt *Initializer = ForLoop->getInit();
122   const Expr *Conditional = ForLoop->getCond();
123   const Expr *Increment = ForLoop->getInc();
124   if (!Initializer || !Conditional || !Increment)
125     return false;
126   // If the loop variable value isn't known, loop bounds are unknown.
127   if (const auto *InitDeclStatement = dyn_cast<DeclStmt>(Initializer)) {
128     if (const auto *VariableDecl =
129             dyn_cast<VarDecl>(InitDeclStatement->getSingleDecl())) {
130       APValue *Evaluation = VariableDecl->evaluateValue();
131       if (!Evaluation || !Evaluation->hasValue())
132         return false;
133     }
134   }
135   // If increment is unary and not one of ++ and --, loop bounds are unknown.
136   if (const auto *Op = dyn_cast<UnaryOperator>(Increment))
137     if (!Op->isIncrementDecrementOp())
138       return false;
139 
140   if (const auto *BinaryOp = dyn_cast<BinaryOperator>(Conditional)) {
141     const Expr *LHS = BinaryOp->getLHS();
142     const Expr *RHS = BinaryOp->getRHS();
143     // If both sides are value dependent or constant, loop bounds are unknown.
144     return LHS->isEvaluatable(*Context) != RHS->isEvaluatable(*Context);
145   }
146   return false; // If it's not a binary operator, loop bounds are unknown.
147 }
148 
getCondExpr(const Stmt * Statement)149 const Expr *UnrollLoopsCheck::getCondExpr(const Stmt *Statement) {
150   if (const auto *ForLoop = dyn_cast<ForStmt>(Statement))
151     return ForLoop->getCond();
152   if (const auto *WhileLoop = dyn_cast<WhileStmt>(Statement))
153     return WhileLoop->getCond();
154   if (const auto *DoWhileLoop = dyn_cast<DoStmt>(Statement))
155     return DoWhileLoop->getCond();
156   if (const auto *CXXRangeLoop = dyn_cast<CXXForRangeStmt>(Statement))
157     return CXXRangeLoop->getCond();
158   llvm_unreachable("Unknown loop");
159 }
160 
hasLargeNumIterations(const Stmt * Statement,const IntegerLiteral * CXXLoopBound,const ASTContext * Context)161 bool UnrollLoopsCheck::hasLargeNumIterations(const Stmt *Statement,
162                                              const IntegerLiteral *CXXLoopBound,
163                                              const ASTContext *Context) {
164   // Because hasKnownBounds is called before this, if this is true, then
165   // CXXLoopBound is also matched.
166   if (isa<CXXForRangeStmt>(Statement)) {
167     assert(CXXLoopBound && "CXX ranged for loop has no loop bound");
168     return exprHasLargeNumIterations(CXXLoopBound, Context);
169   }
170   const auto *ForLoop = cast<ForStmt>(Statement);
171   const Stmt *Initializer = ForLoop->getInit();
172   const Expr *Conditional = ForLoop->getCond();
173   const Expr *Increment = ForLoop->getInc();
174   int InitValue = 0;
175   // If the loop variable value isn't known, we can't know the loop bounds.
176   if (const auto *InitDeclStatement = dyn_cast<DeclStmt>(Initializer)) {
177     if (const auto *VariableDecl =
178             dyn_cast<VarDecl>(InitDeclStatement->getSingleDecl())) {
179       APValue *Evaluation = VariableDecl->evaluateValue();
180       if (!Evaluation || !Evaluation->isInt())
181         return true;
182       InitValue = Evaluation->getInt().getExtValue();
183     }
184   }
185 
186   int EndValue = 0;
187   const auto *BinaryOp = cast<BinaryOperator>(Conditional);
188   if (!extractValue(EndValue, BinaryOp, Context))
189     return true;
190 
191   double Iterations = 0.0;
192 
193   // If increment is unary and not one of ++, --, we can't know the loop bounds.
194   if (const auto *Op = dyn_cast<UnaryOperator>(Increment)) {
195     if (Op->isIncrementOp())
196       Iterations = EndValue - InitValue;
197     else if (Op->isDecrementOp())
198       Iterations = InitValue - EndValue;
199     else
200       llvm_unreachable("Unary operator neither increment nor decrement");
201   }
202 
203   // If increment is binary and not one of +, -, *, /, we can't know the loop
204   // bounds.
205   if (const auto *Op = dyn_cast<BinaryOperator>(Increment)) {
206     int ConstantValue = 0;
207     if (!extractValue(ConstantValue, Op, Context))
208       return true;
209     switch (Op->getOpcode()) {
210     case (BO_AddAssign):
211       Iterations = ceil(float(EndValue - InitValue) / ConstantValue);
212       break;
213     case (BO_SubAssign):
214       Iterations = ceil(float(InitValue - EndValue) / ConstantValue);
215       break;
216     case (BO_MulAssign):
217       Iterations = 1 + (log((double)EndValue) - log((double)InitValue)) /
218                            log((double)ConstantValue);
219       break;
220     case (BO_DivAssign):
221       Iterations = 1 + (log((double)InitValue) - log((double)EndValue)) /
222                            log((double)ConstantValue);
223       break;
224     default:
225       // All other operators are not handled; assume large bounds.
226       return true;
227     }
228   }
229   return Iterations > MaxLoopIterations;
230 }
231 
extractValue(int & Value,const BinaryOperator * Op,const ASTContext * Context)232 bool UnrollLoopsCheck::extractValue(int &Value, const BinaryOperator *Op,
233                                     const ASTContext *Context) {
234   const Expr *LHS = Op->getLHS();
235   const Expr *RHS = Op->getRHS();
236   Expr::EvalResult Result;
237   if (LHS->isEvaluatable(*Context))
238     LHS->EvaluateAsRValue(Result, *Context);
239   else if (RHS->isEvaluatable(*Context))
240     RHS->EvaluateAsRValue(Result, *Context);
241   else
242     return false; // Cannot evaluate either side.
243   if (!Result.Val.isInt())
244     return false; // Cannot check number of iterations, return false to be
245                   // safe.
246   Value = Result.Val.getInt().getExtValue();
247   return true;
248 }
249 
exprHasLargeNumIterations(const Expr * Expression,const ASTContext * Context) const250 bool UnrollLoopsCheck::exprHasLargeNumIterations(const Expr *Expression,
251                                                  const ASTContext *Context) const {
252   Expr::EvalResult Result;
253   if (Expression->EvaluateAsRValue(Result, *Context)) {
254     if (!Result.Val.isInt())
255       return false; // Cannot check number of iterations, return false to be
256                     // safe.
257     // The following assumes values go from 0 to Val in increments of 1.
258     return Result.Val.getInt() > MaxLoopIterations;
259   }
260   // Cannot evaluate Expression as an r-value, so cannot check number of
261   // iterations.
262   return false;
263 }
264 
storeOptions(ClangTidyOptions::OptionMap & Opts)265 void UnrollLoopsCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
266   Options.store(Opts, "MaxLoopIterations", MaxLoopIterations);
267 }
268 
269 } // namespace clang::tidy::altera
270