xref: /llvm-project/clang-tools-extra/clang-reorder-fields/ReorderFieldsAction.cpp (revision fbd86d05fe51d45f19df8d63aee41d979c268f8f)
1 //===-- tools/extra/clang-reorder-fields/ReorderFieldsAction.cpp -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the definition of the
11 /// ReorderFieldsAction::newASTConsumer method
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "ReorderFieldsAction.h"
16 #include "clang/AST/AST.h"
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/Decl.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/ASTMatchers/ASTMatchFinder.h"
22 #include "clang/Lex/Lexer.h"
23 #include "clang/Tooling/Refactoring.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include <string>
27 
28 namespace clang {
29 namespace reorder_fields {
30 using namespace clang::ast_matchers;
31 using llvm::SmallSetVector;
32 
33 /// Finds the definition of a record by name.
34 ///
35 /// \returns nullptr if the name is ambiguous or not found.
36 static const RecordDecl *findDefinition(StringRef RecordName,
37                                         ASTContext &Context) {
38   auto Results =
39       match(recordDecl(hasName(RecordName), isDefinition()).bind("recordDecl"),
40             Context);
41   if (Results.empty()) {
42     llvm::errs() << "Definition of " << RecordName << "  not found\n";
43     return nullptr;
44   }
45   if (Results.size() > 1) {
46     llvm::errs() << "The name " << RecordName
47                  << " is ambiguous, several definitions found\n";
48     return nullptr;
49   }
50   return selectFirst<RecordDecl>("recordDecl", Results);
51 }
52 
53 /// Calculates the new order of fields.
54 ///
55 /// \returns empty vector if the list of fields doesn't match the definition.
56 static SmallVector<unsigned, 4>
57 getNewFieldsOrder(const RecordDecl *Definition,
58                   ArrayRef<std::string> DesiredFieldsOrder) {
59   assert(Definition && "Definition is null");
60 
61   llvm::StringMap<unsigned> NameToIndex;
62   for (const auto *Field : Definition->fields())
63     NameToIndex[Field->getName()] = Field->getFieldIndex();
64 
65   if (DesiredFieldsOrder.size() != NameToIndex.size()) {
66     llvm::errs() << "Number of provided fields (" << DesiredFieldsOrder.size()
67                  << ") doesn't match definition (" << NameToIndex.size()
68                  << ").\n";
69     return {};
70   }
71   SmallVector<unsigned, 4> NewFieldsOrder;
72   for (const auto &Name : DesiredFieldsOrder) {
73     if (!NameToIndex.count(Name)) {
74       llvm::errs() << "Field " << Name << " not found in definition.\n";
75       return {};
76     }
77     NewFieldsOrder.push_back(NameToIndex[Name]);
78   }
79   assert(NewFieldsOrder.size() == NameToIndex.size());
80   return NewFieldsOrder;
81 }
82 
83 // FIXME: error-handling
84 /// Replaces one range of source code by another.
85 static void
86 addReplacement(SourceRange Old, SourceRange New, const ASTContext &Context,
87                std::map<std::string, tooling::Replacements> &Replacements) {
88   StringRef NewText =
89       Lexer::getSourceText(CharSourceRange::getTokenRange(New),
90                            Context.getSourceManager(), Context.getLangOpts());
91   tooling::Replacement R(Context.getSourceManager(),
92                          CharSourceRange::getTokenRange(Old), NewText,
93                          Context.getLangOpts());
94   consumeError(Replacements[std::string(R.getFilePath())].add(R));
95 }
96 
97 /// Find all member fields used in the given init-list initializer expr
98 /// that belong to the same record
99 ///
100 /// \returns a set of field declarations, empty if none were present
101 static SmallSetVector<FieldDecl *, 1>
102 findMembersUsedInInitExpr(const CXXCtorInitializer *Initializer,
103                           ASTContext &Context) {
104   SmallSetVector<FieldDecl *, 1> Results;
105   // Note that this does not pick up member fields of base classes since
106   // for those accesses Sema::PerformObjectMemberConversion always inserts an
107   // UncheckedDerivedToBase ImplicitCastExpr between the this expr and the
108   // object expression
109   auto FoundExprs = match(
110       traverse(
111           TK_AsIs,
112           findAll(memberExpr(hasObjectExpression(cxxThisExpr())).bind("ME"))),
113       *Initializer->getInit(), Context);
114   for (BoundNodes &BN : FoundExprs)
115     if (auto *MemExpr = BN.getNodeAs<MemberExpr>("ME"))
116       if (auto *FD = dyn_cast<FieldDecl>(MemExpr->getMemberDecl()))
117         Results.insert(FD);
118   return Results;
119 }
120 
121 /// Returns the start of the leading comments before `Loc`.
122 static SourceLocation getStartOfLeadingComment(SourceLocation Loc,
123                                                const SourceManager &SM,
124                                                const LangOptions &LangOpts) {
125   // We consider any leading comment token that is on the same line or
126   // indented similarly to the first comment to be part of the leading comment.
127   const unsigned Line = SM.getPresumedLineNumber(Loc);
128   const unsigned Column = SM.getPresumedColumnNumber(Loc);
129   std::optional<Token> Tok =
130       Lexer::findPreviousToken(Loc, SM, LangOpts, /*IncludeComments=*/true);
131   while (Tok && Tok->is(tok::comment)) {
132     const SourceLocation CommentLoc =
133         Lexer::GetBeginningOfToken(Tok->getLocation(), SM, LangOpts);
134     if (SM.getPresumedLineNumber(CommentLoc) != Line &&
135         SM.getPresumedColumnNumber(CommentLoc) != Column) {
136       break;
137     }
138     Loc = CommentLoc;
139     Tok = Lexer::findPreviousToken(Loc, SM, LangOpts, /*IncludeComments=*/true);
140   }
141   return Loc;
142 }
143 
144 /// Returns the end of the trailing comments after `Loc`.
145 static SourceLocation getEndOfTrailingComment(SourceLocation Loc,
146                                               const SourceManager &SM,
147                                               const LangOptions &LangOpts) {
148   // We consider any following comment token that is indented more than the
149   // first comment to be part of the trailing comment.
150   const unsigned Column = SM.getPresumedColumnNumber(Loc);
151   std::optional<Token> Tok =
152       Lexer::findNextToken(Loc, SM, LangOpts, /*IncludeComments=*/true);
153   while (Tok && Tok->is(tok::comment) &&
154          SM.getPresumedColumnNumber(Tok->getLocation()) > Column) {
155     Loc = Tok->getEndLoc();
156     Tok = Lexer::findNextToken(Loc, SM, LangOpts, /*IncludeComments=*/true);
157   }
158   return Loc;
159 }
160 
161 /// Returns the full source range for the field declaration up to (including)
162 /// the trailing semicolumn, including potential macro invocations,
163 /// e.g. `int a GUARDED_BY(mu);`. If there is a trailing comment, include it.
164 static SourceRange getFullFieldSourceRange(const FieldDecl &Field,
165                                            const ASTContext &Context) {
166   const SourceRange Range = Field.getSourceRange();
167   SourceLocation Begin = Range.getBegin();
168   SourceLocation End = Range.getEnd();
169   const SourceManager &SM = Context.getSourceManager();
170   const LangOptions &LangOpts = Context.getLangOpts();
171   while (true) {
172     std::optional<Token> CurrentToken = Lexer::findNextToken(End, SM, LangOpts);
173 
174     if (!CurrentToken)
175       return SourceRange(Begin, End);
176 
177     if (CurrentToken->is(tok::eof))
178       return Range; // Something is wrong, return the original range.
179 
180     End = CurrentToken->getLastLoc();
181 
182     if (CurrentToken->is(tok::semi))
183       break;
184   }
185   Begin = getStartOfLeadingComment(Begin, SM, LangOpts);
186   End = getEndOfTrailingComment(End, SM, LangOpts);
187   return SourceRange(Begin, End);
188 }
189 
190 /// Reorders fields in the definition of a struct/class.
191 ///
192 /// At the moment reordering of fields with
193 /// different accesses (public/protected/private) is not supported.
194 /// \returns true on success.
195 static bool reorderFieldsInDefinition(
196     const RecordDecl *Definition, ArrayRef<unsigned> NewFieldsOrder,
197     const ASTContext &Context,
198     std::map<std::string, tooling::Replacements> &Replacements) {
199   assert(Definition && "Definition is null");
200 
201   SmallVector<const FieldDecl *, 10> Fields;
202   for (const auto *Field : Definition->fields())
203     Fields.push_back(Field);
204 
205   // Check that the permutation of the fields doesn't change the accesses
206   for (const auto *Field : Definition->fields()) {
207     const auto FieldIndex = Field->getFieldIndex();
208     if (Field->getAccess() != Fields[NewFieldsOrder[FieldIndex]]->getAccess()) {
209       llvm::errs() << "Currently reordering of fields with different accesses "
210                       "is not supported\n";
211       return false;
212     }
213   }
214 
215   for (const auto *Field : Definition->fields()) {
216     const auto FieldIndex = Field->getFieldIndex();
217     if (FieldIndex == NewFieldsOrder[FieldIndex])
218       continue;
219     addReplacement(
220         getFullFieldSourceRange(*Field, Context),
221         getFullFieldSourceRange(*Fields[NewFieldsOrder[FieldIndex]], Context),
222         Context, Replacements);
223   }
224   return true;
225 }
226 
227 /// Reorders initializers in a C++ struct/class constructor.
228 ///
229 /// A constructor can have initializers for an arbitrary subset of the class's
230 /// fields. Thus, we need to ensure that we reorder just the initializers that
231 /// are present.
232 static void reorderFieldsInConstructor(
233     const CXXConstructorDecl *CtorDecl, ArrayRef<unsigned> NewFieldsOrder,
234     ASTContext &Context,
235     std::map<std::string, tooling::Replacements> &Replacements) {
236   assert(CtorDecl && "Constructor declaration is null");
237   if (CtorDecl->isImplicit() || CtorDecl->getNumCtorInitializers() <= 1)
238     return;
239 
240   // The method FunctionDecl::isThisDeclarationADefinition returns false
241   // for a defaulted function unless that function has been implicitly defined.
242   // Thus this assert needs to be after the previous checks.
243   assert(CtorDecl->isThisDeclarationADefinition() && "Not a definition");
244 
245   SmallVector<unsigned, 10> NewFieldsPositions(NewFieldsOrder.size());
246   for (unsigned i = 0, e = NewFieldsOrder.size(); i < e; ++i)
247     NewFieldsPositions[NewFieldsOrder[i]] = i;
248 
249   SmallVector<const CXXCtorInitializer *, 10> OldWrittenInitializersOrder;
250   SmallVector<const CXXCtorInitializer *, 10> NewWrittenInitializersOrder;
251   for (const auto *Initializer : CtorDecl->inits()) {
252     if (!Initializer->isMemberInitializer() || !Initializer->isWritten())
253       continue;
254 
255     // Warn if this reordering violates initialization expr dependencies.
256     const FieldDecl *ThisM = Initializer->getMember();
257     const auto UsedMembers = findMembersUsedInInitExpr(Initializer, Context);
258     for (const FieldDecl *UM : UsedMembers) {
259       if (NewFieldsPositions[UM->getFieldIndex()] >
260           NewFieldsPositions[ThisM->getFieldIndex()]) {
261         DiagnosticsEngine &DiagEngine = Context.getDiagnostics();
262         auto Description = ("reordering field " + UM->getName() + " after " +
263                             ThisM->getName() + " makes " + UM->getName() +
264                             " uninitialized when used in init expression")
265                                .str();
266         unsigned ID = DiagEngine.getDiagnosticIDs()->getCustomDiagID(
267             DiagnosticIDs::Warning, Description);
268         DiagEngine.Report(Initializer->getSourceLocation(), ID);
269       }
270     }
271 
272     OldWrittenInitializersOrder.push_back(Initializer);
273     NewWrittenInitializersOrder.push_back(Initializer);
274   }
275   auto ByFieldNewPosition = [&](const CXXCtorInitializer *LHS,
276                                 const CXXCtorInitializer *RHS) {
277     assert(LHS && RHS);
278     return NewFieldsPositions[LHS->getMember()->getFieldIndex()] <
279            NewFieldsPositions[RHS->getMember()->getFieldIndex()];
280   };
281   llvm::sort(NewWrittenInitializersOrder, ByFieldNewPosition);
282   assert(OldWrittenInitializersOrder.size() ==
283          NewWrittenInitializersOrder.size());
284   for (unsigned i = 0, e = NewWrittenInitializersOrder.size(); i < e; ++i)
285     if (OldWrittenInitializersOrder[i] != NewWrittenInitializersOrder[i])
286       addReplacement(OldWrittenInitializersOrder[i]->getSourceRange(),
287                      NewWrittenInitializersOrder[i]->getSourceRange(), Context,
288                      Replacements);
289 }
290 
291 /// Reorders initializers in the brace initialization of an aggregate.
292 ///
293 /// At the moment partial initialization is not supported.
294 /// \returns true on success
295 static bool reorderFieldsInInitListExpr(
296     const InitListExpr *InitListEx, ArrayRef<unsigned> NewFieldsOrder,
297     const ASTContext &Context,
298     std::map<std::string, tooling::Replacements> &Replacements) {
299   assert(InitListEx && "Init list expression is null");
300   // We care only about InitListExprs which originate from source code.
301   // Implicit InitListExprs are created by the semantic analyzer.
302   if (!InitListEx->isExplicit())
303     return true;
304   // The method InitListExpr::getSyntacticForm may return nullptr indicating
305   // that the current initializer list also serves as its syntactic form.
306   if (const auto *SyntacticForm = InitListEx->getSyntacticForm())
307     InitListEx = SyntacticForm;
308   // If there are no initializers we do not need to change anything.
309   if (!InitListEx->getNumInits())
310     return true;
311   if (InitListEx->getNumInits() != NewFieldsOrder.size()) {
312     llvm::errs() << "Currently only full initialization is supported\n";
313     return false;
314   }
315   for (unsigned i = 0, e = InitListEx->getNumInits(); i < e; ++i)
316     if (i != NewFieldsOrder[i])
317       addReplacement(InitListEx->getInit(i)->getSourceRange(),
318                      InitListEx->getInit(NewFieldsOrder[i])->getSourceRange(),
319                      Context, Replacements);
320   return true;
321 }
322 
323 namespace {
324 class ReorderingConsumer : public ASTConsumer {
325   StringRef RecordName;
326   ArrayRef<std::string> DesiredFieldsOrder;
327   std::map<std::string, tooling::Replacements> &Replacements;
328 
329 public:
330   ReorderingConsumer(StringRef RecordName,
331                      ArrayRef<std::string> DesiredFieldsOrder,
332                      std::map<std::string, tooling::Replacements> &Replacements)
333       : RecordName(RecordName), DesiredFieldsOrder(DesiredFieldsOrder),
334         Replacements(Replacements) {}
335 
336   ReorderingConsumer(const ReorderingConsumer &) = delete;
337   ReorderingConsumer &operator=(const ReorderingConsumer &) = delete;
338 
339   void HandleTranslationUnit(ASTContext &Context) override {
340     const RecordDecl *RD = findDefinition(RecordName, Context);
341     if (!RD)
342       return;
343     SmallVector<unsigned, 4> NewFieldsOrder =
344         getNewFieldsOrder(RD, DesiredFieldsOrder);
345     if (NewFieldsOrder.empty())
346       return;
347     if (!reorderFieldsInDefinition(RD, NewFieldsOrder, Context, Replacements))
348       return;
349 
350     // CXXRD will be nullptr if C code (not C++) is being processed.
351     const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD);
352     if (CXXRD)
353       for (const auto *C : CXXRD->ctors())
354         if (const auto *D = dyn_cast<CXXConstructorDecl>(C->getDefinition()))
355           reorderFieldsInConstructor(cast<const CXXConstructorDecl>(D),
356                                      NewFieldsOrder, Context, Replacements);
357 
358     // We only need to reorder init list expressions for
359     // plain C structs or C++ aggregate types.
360     // For other types the order of constructor parameters is used,
361     // which we don't change at the moment.
362     // Now (v0) partial initialization is not supported.
363     if (!CXXRD || CXXRD->isAggregate())
364       for (auto Result :
365            match(initListExpr(hasType(equalsNode(RD))).bind("initListExpr"),
366                  Context))
367         if (!reorderFieldsInInitListExpr(
368                 Result.getNodeAs<InitListExpr>("initListExpr"), NewFieldsOrder,
369                 Context, Replacements)) {
370           Replacements.clear();
371           return;
372         }
373   }
374 };
375 } // end anonymous namespace
376 
377 std::unique_ptr<ASTConsumer> ReorderFieldsAction::newASTConsumer() {
378   return std::make_unique<ReorderingConsumer>(RecordName, DesiredFieldsOrder,
379                                                Replacements);
380 }
381 
382 } // namespace reorder_fields
383 } // namespace clang
384