xref: /llvm-project/clang/unittests/Sema/CodeCompleteTest.cpp (revision 23ef8bf9c0f338ee073c6c1b553c42e46d2f22ad)
1 //=== unittests/Sema/CodeCompleteTest.cpp - Code Complete tests ==============//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Parse/ParseAST.h"
#include "clang/Sema/Sema.h"
#include "clang/Sema/SemaDiagnostic.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Testing/Annotations/Annotations.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
21 
22 namespace {
23 
24 using namespace clang;
25 using namespace clang::tooling;
26 using ::testing::AllOf;
27 using ::testing::Contains;
28 using ::testing::Each;
29 using ::testing::UnorderedElementsAre;
30 
// Name of the virtual main file that every completion run is performed on.
constexpr char TestCCName[] = "test.cc";

// Everything VisitedContextFinder extracts from one completion run.
struct CompletionContext {
  // Qualified names of the namespaces Sema reported as visited.
  std::vector<std::string> VisitedNamespaces;

  // Printed form of the preferred type at the completion point
  // (QualType::getAsString(); the tests expect "NULL TYPE" when there is none).
  std::string PreferredType;

  // Printed form of std::ptrdiff_t for the current target. This is a hack so
  // the tests work across the different configurations clang can be built in.
  std::string PtrDiffType;
};
40 
// One function declaration that was offered as a completion result.
struct CompletedFunctionDecl {
  // Unqualified name of the function (NamedDecl::getNameAsString()).
  std::string Name;
  // FunctionDecl::isStatic() of the completed declaration. Initialized so a
  // default-constructed value never holds an indeterminate bool.
  bool IsStatic = false;
  // CodeCompletionResult::FunctionCanBeCall reported by Sema.
  bool CanBeCall = false;
};
46 MATCHER_P(named, name, "") { return arg.Name == name; }
47 MATCHER_P(isStatic, value, "") { return arg.IsStatic == value; }
48 MATCHER_P(canBeCall, value, "") { return arg.CanBeCall == value; }
49 
50 class SaveCompletedFunctions : public CodeCompleteConsumer {
51 public:
SaveCompletedFunctions(std::vector<CompletedFunctionDecl> & CompletedFuncDecls)52   SaveCompletedFunctions(std::vector<CompletedFunctionDecl> &CompletedFuncDecls)
53       : CodeCompleteConsumer(/*CodeCompleteOpts=*/{}),
54         CompletedFuncDecls(CompletedFuncDecls),
55         CCTUInfo(std::make_shared<GlobalCodeCompletionAllocator>()) {}
56 
ProcessCodeCompleteResults(Sema & S,CodeCompletionContext Context,CodeCompletionResult * Results,unsigned NumResults)57   void ProcessCodeCompleteResults(Sema &S, CodeCompletionContext Context,
58                                   CodeCompletionResult *Results,
59                                   unsigned NumResults) override {
60     for (unsigned I = 0; I < NumResults; ++I) {
61       auto R = Results[I];
62       if (R.Kind == CodeCompletionResult::RK_Declaration) {
63         auto *ND = R.getDeclaration();
64         if (auto *Template = llvm::dyn_cast<FunctionTemplateDecl>(ND))
65           ND = Template->getTemplatedDecl();
66         if (const auto *FD = llvm::dyn_cast<FunctionDecl>(ND)) {
67           CompletedFunctionDecl D;
68           D.Name = FD->getNameAsString();
69           D.CanBeCall = R.FunctionCanBeCall;
70           D.IsStatic = FD->isStatic();
71           CompletedFuncDecls.emplace_back(std::move(D));
72         }
73       }
74     }
75   }
76 
77 private:
getAllocator()78   CodeCompletionAllocator &getAllocator() override {
79     return CCTUInfo.getAllocator();
80   }
81 
getCodeCompletionTUInfo()82   CodeCompletionTUInfo &getCodeCompletionTUInfo() override { return CCTUInfo; }
83 
84   std::vector<CompletedFunctionDecl> &CompletedFuncDecls;
85 
86   CodeCompletionTUInfo CCTUInfo;
87 };
88 
89 class VisitedContextFinder : public CodeCompleteConsumer {
90 public:
VisitedContextFinder(CompletionContext & ResultCtx)91   VisitedContextFinder(CompletionContext &ResultCtx)
92       : CodeCompleteConsumer(/*CodeCompleteOpts=*/{}), ResultCtx(ResultCtx),
93         CCTUInfo(std::make_shared<GlobalCodeCompletionAllocator>()) {}
94 
ProcessCodeCompleteResults(Sema & S,CodeCompletionContext Context,CodeCompletionResult * Results,unsigned NumResults)95   void ProcessCodeCompleteResults(Sema &S, CodeCompletionContext Context,
96                                   CodeCompletionResult *Results,
97                                   unsigned NumResults) override {
98     ResultCtx.VisitedNamespaces =
99         getVisitedNamespace(Context.getVisitedContexts());
100     ResultCtx.PreferredType = Context.getPreferredType().getAsString();
101     ResultCtx.PtrDiffType =
102         S.getASTContext().getPointerDiffType().getAsString();
103   }
104 
getAllocator()105   CodeCompletionAllocator &getAllocator() override {
106     return CCTUInfo.getAllocator();
107   }
108 
getCodeCompletionTUInfo()109   CodeCompletionTUInfo &getCodeCompletionTUInfo() override { return CCTUInfo; }
110 
111 private:
getVisitedNamespace(CodeCompletionContext::VisitedContextSet VisitedContexts) const112   std::vector<std::string> getVisitedNamespace(
113       CodeCompletionContext::VisitedContextSet VisitedContexts) const {
114     std::vector<std::string> NSNames;
115     for (const auto *Context : VisitedContexts)
116       if (const auto *NS = llvm::dyn_cast<NamespaceDecl>(Context))
117         NSNames.push_back(NS->getQualifiedNameAsString());
118     return NSNames;
119   }
120 
121   CompletionContext &ResultCtx;
122   CodeCompletionTUInfo CCTUInfo;
123 };
124 
125 class CodeCompleteAction : public SyntaxOnlyAction {
126 public:
CodeCompleteAction(ParsedSourceLocation P,CodeCompleteConsumer * Consumer)127   CodeCompleteAction(ParsedSourceLocation P, CodeCompleteConsumer *Consumer)
128       : CompletePosition(std::move(P)), Consumer(Consumer) {}
129 
BeginInvocation(CompilerInstance & CI)130   bool BeginInvocation(CompilerInstance &CI) override {
131     CI.getFrontendOpts().CodeCompletionAt = CompletePosition;
132     CI.setCodeCompletionConsumer(Consumer);
133     return true;
134   }
135 
136 private:
137   // 1-based code complete position <Line, Col>;
138   ParsedSourceLocation CompletePosition;
139   CodeCompleteConsumer *Consumer;
140 };
141 
offsetToPosition(llvm::StringRef Code,size_t Offset)142 ParsedSourceLocation offsetToPosition(llvm::StringRef Code, size_t Offset) {
143   Offset = std::min(Code.size(), Offset);
144   StringRef Before = Code.substr(0, Offset);
145   int Lines = Before.count('\n');
146   size_t PrevNL = Before.rfind('\n');
147   size_t StartOfLine = (PrevNL == StringRef::npos) ? 0 : (PrevNL + 1);
148   return {TestCCName, static_cast<unsigned>(Lines + 1),
149           static_cast<unsigned>(Offset - StartOfLine + 1)};
150 }
151 
runCompletion(StringRef Code,size_t Offset)152 CompletionContext runCompletion(StringRef Code, size_t Offset) {
153   CompletionContext ResultCtx;
154   clang::tooling::runToolOnCodeWithArgs(
155       std::make_unique<CodeCompleteAction>(offsetToPosition(Code, Offset),
156                                            new VisitedContextFinder(ResultCtx)),
157       Code, {"-std=c++11"}, TestCCName);
158   return ResultCtx;
159 }
160 
runCodeCompleteOnCode(StringRef AnnotatedCode)161 CompletionContext runCodeCompleteOnCode(StringRef AnnotatedCode) {
162   llvm::Annotations A(AnnotatedCode);
163   return runCompletion(A.code(), A.point());
164 }
165 
166 std::vector<std::string>
collectPreferredTypes(StringRef AnnotatedCode,std::string * PtrDiffType=nullptr)167 collectPreferredTypes(StringRef AnnotatedCode,
168                       std::string *PtrDiffType = nullptr) {
169   llvm::Annotations A(AnnotatedCode);
170   std::vector<std::string> Types;
171   for (size_t Point : A.points()) {
172     auto Results = runCompletion(A.code(), Point);
173     if (PtrDiffType) {
174       assert(PtrDiffType->empty() || *PtrDiffType == Results.PtrDiffType);
175       *PtrDiffType = Results.PtrDiffType;
176     }
177     Types.push_back(Results.PreferredType);
178   }
179   return Types;
180 }
181 
182 std::vector<CompletedFunctionDecl>
CollectCompletedFunctions(StringRef Code,std::size_t Point)183 CollectCompletedFunctions(StringRef Code, std::size_t Point) {
184   std::vector<CompletedFunctionDecl> Result;
185   clang::tooling::runToolOnCodeWithArgs(
186       std::make_unique<CodeCompleteAction>(offsetToPosition(Code, Point),
187                                            new SaveCompletedFunctions(Result)),
188       Code, {"-std=c++11"}, TestCCName);
189   return Result;
190 }
191 
TEST(SemaCodeCompleteTest, FunctionCanBeCall) {
  llvm::Annotations Code(R"cpp(
    struct Foo {
      static int staticMethod();
      int method() const;
      template <typename T, typename U, typename V = int>
      T generic(U, V);
      template <typename T, int U = 3>
      static T staticGeneric();
      Foo() {
        this->$canBeCall^
        $canBeCall^
        Foo::$canBeCall^
      }
    };

    struct Derived : Foo {
      using Foo::method;
      using Foo::generic;
      Derived() {
        Foo::$canBeCall^
      }
    };

    struct OtherClass {
      OtherClass() {
        Foo f;
        Derived d;
        f.$canBeCall^
        ; // Prevent parsing as 'f.f'
        f.Foo::$canBeCall^
        &Foo::$cannotBeCall^
        ;
        d.Foo::$canBeCall^
        ;
        d.Derived::$canBeCall^
      }
    };

    int main() {
      Foo f;
      Derived d;
      f.$canBeCall^
      ; // Prevent parsing as 'f.f'
      f.Foo::$canBeCall^
      &Foo::$cannotBeCall^
      ;
      d.Foo::$canBeCall^
      ;
      d.Derived::$canBeCall^
    }
    )cpp");

  // Runs completion once per point named PointName and checks both kinds of
  // method:
  //  - instance methods are callable exactly when InstanceCanBeCall is set;
  //  - static methods can always be a call.
  // Annotations::points() with no argument returns only *unnamed* points, and
  // every point above is named. Iterating it would therefore check nothing,
  // so the static expectations are folded into the named iterations here.
  auto CheckPoints = [&Code](llvm::StringRef PointName,
                             bool InstanceCanBeCall) {
    for (size_t P : Code.points(PointName)) {
      SCOPED_TRACE(::testing::Message()
                   << "point $" << PointName.str() << "^ at offset " << P);
      auto Results = CollectCompletedFunctions(Code.code(), P);
      EXPECT_THAT(Results,
                  Contains(AllOf(named("method"), isStatic(false),
                                 canBeCall(InstanceCanBeCall))));
      EXPECT_THAT(Results,
                  Contains(AllOf(named("generic"), isStatic(false),
                                 canBeCall(InstanceCanBeCall))));
      EXPECT_THAT(Results, Contains(AllOf(named("staticMethod"), isStatic(true),
                                          canBeCall(true))));
      EXPECT_THAT(Results, Contains(AllOf(named("staticGeneric"),
                                          isStatic(true), canBeCall(true))));
    }
  };

  CheckPoints("canBeCall", /*InstanceCanBeCall=*/true);
  CheckPoints("cannotBeCall", /*InstanceCanBeCall=*/false);
}
270 
TEST(SemaCodeCompleteTest, VisitedNSForValidQualifiedId) {
  // Completing after the valid qualifier `foo::` is expected to visit foo, its
  // anonymous namespace, and the namespaces named by using-directives inside
  // foo, but not foo's unrelated nested namespace ns4.
  const CompletionContext Ctx = runCodeCompleteOnCode(R"cpp(
     namespace ns1 {}
     namespace ns2 {}
     namespace ns3 {}
     namespace ns3 { namespace nns3 {} }

     namespace foo {
     using namespace ns1;
     namespace ns4 {} // not visited
     namespace { using namespace ns2; }
     inline namespace bar { using namespace ns3::nns3; }
     } // foo
     namespace ns { foo::^ }
  )cpp");
  const std::vector<std::string> &Visited = Ctx.VisitedNamespaces;
  EXPECT_THAT(Visited, UnorderedElementsAre("foo", "ns1", "ns2", "ns3::nns3",
                                            "foo::(anonymous)"));
}
290 
TEST(SemaCodeCompleteTest, VisitedNSForInvalidQualifiedId) {
  // `foo::` does not resolve here; the enclosing namespace and the namespace
  // it uses are still expected to be visited.
  const CompletionContext Ctx = runCodeCompleteOnCode(R"cpp(
     namespace na {}
     namespace ns1 {
     using namespace na;
     foo::^
     }
  )cpp");
  const std::vector<std::string> &Visited = Ctx.VisitedNamespaces;
  EXPECT_THAT(Visited, UnorderedElementsAre("ns1", "na"));
}
302 
TEST(SemaCodeCompleteTest, VisitedNSWithoutQualifier) {
  // Unqualified completion inside nested namespaces visits each of them.
  const CompletionContext Ctx = runCodeCompleteOnCode(R"cpp(
    namespace n1 {
    namespace n2 {
      void f(^) {}
    }
    }
  )cpp");
  const std::vector<std::string> &Visited = Ctx.VisitedNamespaces;
  EXPECT_THAT(Visited, UnorderedElementsAre("n1", "n1::n2"));
}
314 
TEST(PreferredTypeTest, BinaryExpr) {
  // Check various operations for arithmetic types.
  StringRef Code = R"cpp(
    void test(int x) {
      x = ^10;
      x += ^10; x -= ^10; x *= ^10; x /= ^10; x %= ^10;
      x + ^10; x - ^10; x * ^10; x / ^10; x % ^10;
    })cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("int"));

  Code = R"cpp(
    void test(float x) {
      x = ^10;
      x += ^10; x -= ^10; x *= ^10; x /= ^10; x %= ^10;
      x + ^10; x - ^10; x * ^10; x / ^10; x % ^10;
    })cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("float"));

  // Pointer types.
  Code = R"cpp(
    void test(int *ptr) {
      ptr - ^ptr;
      ptr = ^ptr;
    })cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));

  // Pointer arithmetic prefers the target's ptrdiff_t type.
  Code = R"cpp(
    void test(int *ptr) {
      ptr + ^10;
      ptr += ^10;
      ptr -= ^10;
    })cpp";
  {
    std::string PtrDiff;
    auto Types = collectPreferredTypes(Code, &PtrDiff);
    EXPECT_THAT(Types, Each(PtrDiff));
  }

  // Comparison operators.
  Code = R"cpp(
    void test(int i) {
      i <= ^1; i < ^1; i >= ^1; i > ^1; i == ^1; i != ^1;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("int"));

  Code = R"cpp(
    void test(int *ptr) {
      ptr <= ^ptr; ptr < ^ptr; ptr >= ^ptr; ptr > ^ptr;
      ptr == ^ptr; ptr != ^ptr;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));

  // Logical operations.
  Code = R"cpp(
    void test(int i, int *ptr) {
      i && ^1; i || ^1;
      ptr && ^1; ptr || ^1;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));

  // Bitwise operations.
  Code = R"cpp(
    void test(long long ll) {
      ll | ^1; ll & ^1;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("long long"));

  Code = R"cpp(
    enum A {};
    void test(A a) {
      a | ^1; a & ^1;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("A"));

  Code = R"cpp(
    enum class A {};
    void test(A a) {
      // This is technically illegal with the 'enum class' without overloaded
      // operators, but we pretend it's fine.
      a | ^a; a & ^a;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("A"));

  // Binary shifts: the shift amount prefers 'int' whatever the integer type of
  // the left-hand side, for both plain and compound-assignment forms.
  Code = R"cpp(
    void test(int i, long long ll) {
      i << ^1; ll << ^1;
      i <<= ^1; ll <<= ^1;
      i >> ^1; ll >> ^1;
      i >>= ^1; ll >>= ^1;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("int"));

  // Comma does not provide any useful information.
  Code = R"cpp(
    class Cls {};
    void test(int i, int* ptr, Cls x) {
      (i, ^i);
      (ptr, ^ptr);
      (x, ^x);
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));

  // User-defined types do not take operator overloading into account.
  // However, they provide heuristics for some common cases.
  Code = R"cpp(
    class Cls {};
    void test(Cls c) {
      // we assume arithmetic and comparions ops take the same type.
      c + ^c; c - ^c; c * ^c; c / ^c; c % ^c;
      c == ^c; c != ^c; c < ^c; c <= ^c; c > ^c; c >= ^c;
      // same for the assignments.
      c = ^c; c += ^c; c -= ^c; c *= ^c; c /= ^c; c %= ^c;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("Cls"));

  Code = R"cpp(
    class Cls {};
    void test(Cls c) {
      // we assume relational ops operate on bools.
      c && ^c; c || ^c;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));

  Code = R"cpp(
    class Cls {};
    void test(Cls c) {
      // we make no assumptions about the following operators, since they are
      // often overloaded with a non-standard meaning.
      c << ^c; c >> ^c; c | ^c; c & ^c;
      c <<= ^c; c >>= ^c; c |= ^c; c &= ^c;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
}
460 
TEST(PreferredTypeTest, Members) {
  // Every link of the member-call chain is expected to report the type
  // required by the enclosing assignment.
  const StringRef ChainedCalls = R"cpp(
    struct vector {
      int *begin();
      vector clone();
    };

    void test(int *a) {
      a = ^vector().^clone().^begin();
    }
  )cpp";
  const auto Types = collectPreferredTypes(ChainedCalls);
  EXPECT_THAT(Types, Each("int *"));
}
474 
TEST(PreferredTypeTest, Conditions) {
  // Conditions of if / while / for, including each call in the chain that
  // computes them, are expected to prefer bool.
  const StringRef Snippet = R"cpp(
    struct vector {
      bool empty();
    };

    void test() {
      if (^vector().^empty()) {}
      while (^vector().^empty()) {}
      for (; ^vector().^empty();) {}
    }
  )cpp";
  const auto Types = collectPreferredTypes(Snippet);
  EXPECT_THAT(Types, Each("_Bool"));
}
489 
TEST(PreferredTypeTest, InitAndAssignment) {
  // Initializers (including a condition declaration) and assignments are
  // expected to prefer the declared type of the variable.
  const StringRef Snippet = R"cpp(
    struct vector {
      int* begin();
    };

    void test() {
      const int* x = ^vector().^begin();
      x = ^vector().^begin();

      if (const int* y = ^vector().^begin()) {}
    }
  )cpp";
  const auto Types = collectPreferredTypes(Snippet);
  EXPECT_THAT(Types, Each("const int *"));
}
505 
TEST(PreferredTypeTest, UnaryExprs) {
  // Sign and increment/decrement operators keep the surrounding type.
  StringRef Code = R"cpp(
    void test(long long a) {
      a = +^a;
      a = -^a;
      a = ++^a;
      a = --^a;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("long long"));

  // Logical negation always prefers bool.
  Code = R"cpp(
    void test(int a, int *ptr) {
      !^a;
      !^ptr;
      !!!^a;

      a = !^a;
      a = !^ptr;
      a = !!!^a;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));

  // Address-of prefers the pointee of the expected pointer type.
  Code = R"cpp(
    void test(int a) {
      const int* x = &^a;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("const int"));

  // Dereference prefers a pointer to the expected (non-reference) type.
  Code = R"cpp(
    void test(int *a) {
      int x = *^a;
      int &r = *^a;
    }
  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));

  // Without an expected type, '*' and '&' express no preference. This snippet
  // was previously built but never checked.
  Code = R"cpp(
    void test(int a) {
      *^a;
      &^a;
    }

  )cpp";
  EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
}
553 
TEST(PreferredTypeTest, ParenExpr) {
  // Each level of parentheses is expected to pass the initializer's type
  // through unchanged.
  const StringRef Snippet = R"cpp(
    const int *i = ^(^(^(^10)));
  )cpp";
  const auto Types = collectPreferredTypes(Snippet);
  EXPECT_THAT(Types, Each("const int *"));
}
560 
TEST(PreferredTypeTest, FunctionArguments) {
  // Each case pairs an annotated snippet with the type expected at every one
  // of its completion points.
  struct TestCase {
    StringRef Code;
    const char *ExpectedType;
  };
  const TestCase Cases[] = {
      // Plain and overloaded free functions; the argument may be
      // parenthesized and built from a member-call chain.
      {R"cpp(
    void foo(const int*);

    void bar(const int*);
    void bar(const int*, int b);

    struct vector {
      const int *data();
    };
    void test() {
      foo(^(^(^(^vec^tor^().^da^ta^()))));
      bar(^(^(^(^vec^tor^().^da^ta^()))));
    }
  )cpp",
       "const int *"},
      // Second argument of overloaded free and member functions.
      {R"cpp(
    void bar(int, volatile double *);
    void bar(int, volatile double *, int, int);

    struct vector {
      double *data();
    };

    struct class_members {
      void bar(int, volatile double *);
      void bar(int, volatile double *, int, int);
    };
    void test() {
      bar(10, ^(^(^(^vec^tor^().^da^ta^()))));
      class_members().bar(10, ^(^(^(^vec^tor^().^da^ta^()))));
    }
  )cpp",
       "volatile double *"},
      // Completion inside a qualified name.
      {R"cpp(
    namespace ns {
      struct vector {
      };
    }
    void accepts_vector(ns::vector);

    void test() {
      accepts_vector(^::^ns::^vector());
    }
  )cpp",
       "ns::vector"},
      // Completion inside a nested-name-specifier naming a template
      // specialization.
      {R"cpp(
    template <class T>
    struct vector { using self = vector; };

    void accepts_vector(vector<int>);
    int foo(int);

    void test() {
      accepts_vector(^::^vector<decltype(foo(1))>::^self);
    }
  )cpp",
       "vector<int>"},
  };

  for (const TestCase &Case : Cases)
    EXPECT_THAT(collectPreferredTypes(Case.Code), Each(Case.ExpectedType));
}
623 
TEST(PreferredTypeTest, NoCrashOnInvalidTypes) {
  // Completion inside constructs with invalid types must not crash and should
  // report no preferred type.
  const StringRef Snippet = R"cpp(
    auto x = decltype(&1)(^);
    auto y = new decltype(&1)(^);
    // GNU decimal type extension is not supported in clang.
    auto z = new _Decimal128(^);
    void foo() { (void)(foo)(^); }
  )cpp";
  const auto Types = collectPreferredTypes(Snippet);
  EXPECT_THAT(Types, Each("NULL TYPE"));
}
634 
635 } // namespace
636