1 //===--- Quality.cpp ---------------------------------------------*- C++-*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Quality.h" 10 #include "AST.h" 11 #include "ASTSignals.h" 12 #include "FileDistance.h" 13 #include "SourceCode.h" 14 #include "index/Symbol.h" 15 #include "clang/AST/ASTContext.h" 16 #include "clang/AST/Decl.h" 17 #include "clang/AST/DeclCXX.h" 18 #include "clang/AST/DeclTemplate.h" 19 #include "clang/AST/DeclVisitor.h" 20 #include "clang/Basic/SourceManager.h" 21 #include "clang/Sema/CodeCompleteConsumer.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "llvm/Support/Casting.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include "llvm/Support/MathExtras.h" 26 #include "llvm/Support/raw_ostream.h" 27 #include <algorithm> 28 #include <cmath> 29 #include <optional> 30 31 namespace clang { 32 namespace clangd { 33 34 static bool hasDeclInMainFile(const Decl &D) { 35 auto &SourceMgr = D.getASTContext().getSourceManager(); 36 for (auto *Redecl : D.redecls()) { 37 if (isInsideMainFile(Redecl->getLocation(), SourceMgr)) 38 return true; 39 } 40 return false; 41 } 42 43 static bool hasUsingDeclInMainFile(const CodeCompletionResult &R) { 44 const auto &Context = R.Declaration->getASTContext(); 45 const auto &SourceMgr = Context.getSourceManager(); 46 if (R.ShadowDecl) { 47 if (isInsideMainFile(R.ShadowDecl->getLocation(), SourceMgr)) 48 return true; 49 } 50 return false; 51 } 52 53 static SymbolQualitySignals::SymbolCategory categorize(const NamedDecl &ND) { 54 if (const auto *FD = dyn_cast<FunctionDecl>(&ND)) { 55 if (FD->isOverloadedOperator()) 56 return SymbolQualitySignals::Operator; 57 } 58 class Switch 59 : public ConstDeclVisitor<Switch, SymbolQualitySignals::SymbolCategory> { 60 public: 61 #define MAP(DeclType, Category) \ 62 SymbolQualitySignals::SymbolCategory Visit##DeclType(const DeclType *) { \ 63 return SymbolQualitySignals::Category; \ 64 } 65 MAP(NamespaceDecl, Namespace); 66 MAP(NamespaceAliasDecl, Namespace); 67 MAP(TypeDecl, Type); 68 MAP(TypeAliasTemplateDecl, Type); 69 MAP(ClassTemplateDecl, Type); 70 MAP(CXXConstructorDecl, Constructor); 71 MAP(CXXDestructorDecl, Destructor); 72 MAP(ValueDecl, Variable); 73 MAP(VarTemplateDecl, Variable); 74 MAP(FunctionDecl, Function); 75 MAP(FunctionTemplateDecl, Function); 76 MAP(Decl, Unknown); 77 #undef MAP 78 }; 79 return Switch().Visit(&ND); 80 } 81 82 static SymbolQualitySignals::SymbolCategory 83 categorize(const CodeCompletionResult &R) { 84 if (R.Declaration) 85 return categorize(*R.Declaration); 86 if (R.Kind == CodeCompletionResult::RK_Macro) 87 return SymbolQualitySignals::Macro; 88 // Everything else is a keyword or a pattern. Patterns are mostly keywords 89 // too, except a few which we recognize by cursor kind. 90 switch (R.CursorKind) { 91 case CXCursor_CXXMethod: 92 return SymbolQualitySignals::Function; 93 case CXCursor_ModuleImportDecl: 94 return SymbolQualitySignals::Namespace; 95 case CXCursor_MacroDefinition: 96 return SymbolQualitySignals::Macro; 97 case CXCursor_TypeRef: 98 return SymbolQualitySignals::Type; 99 case CXCursor_MemberRef: 100 return SymbolQualitySignals::Variable; 101 case CXCursor_Constructor: 102 return SymbolQualitySignals::Constructor; 103 default: 104 return SymbolQualitySignals::Keyword; 105 } 106 } 107 108 static SymbolQualitySignals::SymbolCategory 109 categorize(const index::SymbolInfo &D) { 110 switch (D.Kind) { 111 case index::SymbolKind::Namespace: 112 case index::SymbolKind::NamespaceAlias: 113 return SymbolQualitySignals::Namespace; 114 case index::SymbolKind::Macro: 115 return SymbolQualitySignals::Macro; 116 case index::SymbolKind::Enum: 117 case index::SymbolKind::Struct: 118 case index::SymbolKind::Class: 119 case index::SymbolKind::Protocol: 120 case index::SymbolKind::Extension: 121 case index::SymbolKind::Union: 122 case index::SymbolKind::TypeAlias: 123 case index::SymbolKind::TemplateTypeParm: 124 case index::SymbolKind::TemplateTemplateParm: 125 case index::SymbolKind::Concept: 126 return SymbolQualitySignals::Type; 127 case index::SymbolKind::Function: 128 case index::SymbolKind::ClassMethod: 129 case index::SymbolKind::InstanceMethod: 130 case index::SymbolKind::StaticMethod: 131 case index::SymbolKind::InstanceProperty: 132 case index::SymbolKind::ClassProperty: 133 case index::SymbolKind::StaticProperty: 134 case index::SymbolKind::ConversionFunction: 135 return SymbolQualitySignals::Function; 136 case index::SymbolKind::Destructor: 137 return SymbolQualitySignals::Destructor; 138 case index::SymbolKind::Constructor: 139 return SymbolQualitySignals::Constructor; 140 case index::SymbolKind::Variable: 141 case index::SymbolKind::Field: 142 case index::SymbolKind::EnumConstant: 143 case index::SymbolKind::Parameter: 144 case index::SymbolKind::NonTypeTemplateParm: 145 return SymbolQualitySignals::Variable; 146 case index::SymbolKind::Using: 147 case index::SymbolKind::Module: 148 case index::SymbolKind::Unknown: 149 return SymbolQualitySignals::Unknown; 150 } 151 llvm_unreachable("Unknown index::SymbolKind"); 152 } 153 154 static bool isInstanceMember(const NamedDecl *ND) { 155 if (!ND) 156 return false; 157 if (const auto *TP = dyn_cast<FunctionTemplateDecl>(ND)) 158 ND = TP->TemplateDecl::getTemplatedDecl(); 159 if (const auto *CM = dyn_cast<CXXMethodDecl>(ND)) 160 return !CM->isStatic(); 161 return isa<FieldDecl>(ND); // Note that static fields are VarDecl. 162 } 163 164 static bool isInstanceMember(const index::SymbolInfo &D) { 165 switch (D.Kind) { 166 case index::SymbolKind::InstanceMethod: 167 case index::SymbolKind::InstanceProperty: 168 case index::SymbolKind::Field: 169 return true; 170 default: 171 return false; 172 } 173 } 174 175 void SymbolQualitySignals::merge(const CodeCompletionResult &SemaCCResult) { 176 Deprecated |= (SemaCCResult.Availability == CXAvailability_Deprecated); 177 Category = categorize(SemaCCResult); 178 179 if (SemaCCResult.Declaration) { 180 ImplementationDetail |= isImplementationDetail(SemaCCResult.Declaration); 181 if (auto *ID = SemaCCResult.Declaration->getIdentifier()) 182 ReservedName = ReservedName || isReservedName(ID->getName()); 183 } else if (SemaCCResult.Kind == CodeCompletionResult::RK_Macro) 184 ReservedName = 185 ReservedName || isReservedName(SemaCCResult.Macro->getName()); 186 } 187 188 void SymbolQualitySignals::merge(const Symbol &IndexResult) { 189 Deprecated |= (IndexResult.Flags & Symbol::Deprecated); 190 ImplementationDetail |= (IndexResult.Flags & Symbol::ImplementationDetail); 191 References = std::max(IndexResult.References, References); 192 Category = categorize(IndexResult.SymInfo); 193 ReservedName = ReservedName || isReservedName(IndexResult.Name); 194 } 195 196 float SymbolQualitySignals::evaluateHeuristics() const { 197 float Score = 1; 198 199 // This avoids a sharp gradient for tail symbols, and also neatly avoids the 200 // question of whether 0 references means a bad symbol or missing data. 201 if (References >= 10) { 202 // Use a sigmoid style boosting function, which flats out nicely for large 203 // numbers (e.g. 2.58 for 1M references). 204 // The following boosting function is equivalent to: 205 // m = 0.06 206 // f = 12.0 207 // boost = f * sigmoid(m * std::log(References)) - 0.5 * f + 0.59 208 // Sample data points: (10, 1.00), (100, 1.41), (1000, 1.82), 209 // (10K, 2.21), (100K, 2.58), (1M, 2.94) 210 float S = std::pow(References, -0.06); 211 Score *= 6.0 * (1 - S) / (1 + S) + 0.59; 212 } 213 214 if (Deprecated) 215 Score *= 0.1f; 216 if (ReservedName) 217 Score *= 0.1f; 218 if (ImplementationDetail) 219 Score *= 0.2f; 220 221 switch (Category) { 222 case Keyword: // Often relevant, but misses most signals. 223 Score *= 4; // FIXME: important keywords should have specific boosts. 224 break; 225 case Type: 226 case Function: 227 case Variable: 228 Score *= 1.1f; 229 break; 230 case Namespace: 231 Score *= 0.8f; 232 break; 233 case Macro: 234 case Destructor: 235 case Operator: 236 Score *= 0.5f; 237 break; 238 case Constructor: // No boost constructors so they are after class types. 239 case Unknown: 240 break; 241 } 242 243 return Score; 244 } 245 246 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, 247 const SymbolQualitySignals &S) { 248 OS << llvm::formatv("=== Symbol quality: {0}\n", S.evaluateHeuristics()); 249 OS << llvm::formatv("\tReferences: {0}\n", S.References); 250 OS << llvm::formatv("\tDeprecated: {0}\n", S.Deprecated); 251 OS << llvm::formatv("\tReserved name: {0}\n", S.ReservedName); 252 OS << llvm::formatv("\tImplementation detail: {0}\n", S.ImplementationDetail); 253 OS << llvm::formatv("\tCategory: {0}\n", static_cast<int>(S.Category)); 254 return OS; 255 } 256 257 static SymbolRelevanceSignals::AccessibleScope 258 computeScope(const NamedDecl *D) { 259 // Injected "Foo" within the class "Foo" has file scope, not class scope. 260 const DeclContext *DC = D->getDeclContext(); 261 if (auto *R = dyn_cast_or_null<RecordDecl>(D)) 262 if (R->isInjectedClassName()) 263 DC = DC->getParent(); 264 // Class constructor should have the same scope as the class. 265 if (isa<CXXConstructorDecl>(D)) 266 DC = DC->getParent(); 267 bool InClass = false; 268 for (; !DC->isFileContext(); DC = DC->getParent()) { 269 if (DC->isFunctionOrMethod()) 270 return SymbolRelevanceSignals::FunctionScope; 271 InClass = InClass || DC->isRecord(); 272 } 273 if (InClass) 274 return SymbolRelevanceSignals::ClassScope; 275 // ExternalLinkage threshold could be tweaked, e.g. module-visible as global. 276 // Avoid caching linkage if it may change after enclosing code completion. 277 if (hasUnstableLinkage(D) || llvm::to_underlying(D->getLinkageInternal()) < 278 llvm::to_underlying(Linkage::External)) 279 return SymbolRelevanceSignals::FileScope; 280 return SymbolRelevanceSignals::GlobalScope; 281 } 282 283 void SymbolRelevanceSignals::merge(const Symbol &IndexResult) { 284 SymbolURI = IndexResult.CanonicalDeclaration.FileURI; 285 SymbolScope = IndexResult.Scope; 286 IsInstanceMember |= isInstanceMember(IndexResult.SymInfo); 287 if (!(IndexResult.Flags & Symbol::VisibleOutsideFile)) { 288 Scope = AccessibleScope::FileScope; 289 } 290 if (MainFileSignals) { 291 MainFileRefs = 292 std::max(MainFileRefs, 293 MainFileSignals->ReferencedSymbols.lookup(IndexResult.ID)); 294 ScopeRefsInFile = 295 std::max(ScopeRefsInFile, 296 MainFileSignals->RelatedNamespaces.lookup(IndexResult.Scope)); 297 } 298 } 299 300 void SymbolRelevanceSignals::computeASTSignals( 301 const CodeCompletionResult &SemaResult) { 302 if (!MainFileSignals) 303 return; 304 if ((SemaResult.Kind != CodeCompletionResult::RK_Declaration) && 305 (SemaResult.Kind != CodeCompletionResult::RK_Pattern)) 306 return; 307 if (const NamedDecl *ND = SemaResult.getDeclaration()) { 308 if (hasUnstableLinkage(ND)) 309 return; 310 auto ID = getSymbolID(ND); 311 if (!ID) 312 return; 313 MainFileRefs = 314 std::max(MainFileRefs, MainFileSignals->ReferencedSymbols.lookup(ID)); 315 if (const auto *NSD = dyn_cast<NamespaceDecl>(ND->getDeclContext())) { 316 if (NSD->isAnonymousNamespace()) 317 return; 318 std::string Scope = printNamespaceScope(*NSD); 319 if (!Scope.empty()) 320 ScopeRefsInFile = std::max( 321 ScopeRefsInFile, MainFileSignals->RelatedNamespaces.lookup(Scope)); 322 } 323 } 324 } 325 326 void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) { 327 if (SemaCCResult.Availability == CXAvailability_NotAvailable || 328 SemaCCResult.Availability == CXAvailability_NotAccessible) 329 Forbidden = true; 330 331 if (SemaCCResult.Declaration) { 332 SemaSaysInScope = true; 333 // We boost things that have decls in the main file. We give a fixed score 334 // for all other declarations in sema as they are already included in the 335 // translation unit. 336 float DeclProximity = (hasDeclInMainFile(*SemaCCResult.Declaration) || 337 hasUsingDeclInMainFile(SemaCCResult)) 338 ? 1.0 339 : 0.6; 340 SemaFileProximityScore = std::max(DeclProximity, SemaFileProximityScore); 341 IsInstanceMember |= isInstanceMember(SemaCCResult.Declaration); 342 InBaseClass |= SemaCCResult.InBaseClass; 343 } 344 345 computeASTSignals(SemaCCResult); 346 // Declarations are scoped, others (like macros) are assumed global. 347 if (SemaCCResult.Declaration) 348 Scope = std::min(Scope, computeScope(SemaCCResult.Declaration)); 349 350 NeedsFixIts = !SemaCCResult.FixIts.empty(); 351 } 352 353 static float fileProximityScore(unsigned FileDistance) { 354 // Range: [0, 1] 355 // FileDistance = [0, 1, 2, 3, 4, .., FileDistance::Unreachable] 356 // Score = [1, 0.82, 0.67, 0.55, 0.45, .., 0] 357 if (FileDistance == FileDistance::Unreachable) 358 return 0; 359 // Assume approximately default options are used for sensible scoring. 360 return std::exp(FileDistance * -0.4f / FileDistanceOptions().UpCost); 361 } 362 363 static float scopeProximityScore(unsigned ScopeDistance) { 364 // Range: [0.6, 2]. 365 // ScopeDistance = [0, 1, 2, 3, 4, 5, 6, 7, .., FileDistance::Unreachable] 366 // Score = [2.0, 1.55, 1.2, 0.93, 0.72, 0.65, 0.65, 0.65, .., 0.6] 367 if (ScopeDistance == FileDistance::Unreachable) 368 return 0.6f; 369 return std::max(0.65, 2.0 * std::pow(0.6, ScopeDistance / 2.0)); 370 } 371 372 static std::optional<llvm::StringRef> 373 wordMatching(llvm::StringRef Name, const llvm::StringSet<> *ContextWords) { 374 if (ContextWords) 375 for (const auto &Word : ContextWords->keys()) 376 if (Name.contains_insensitive(Word)) 377 return Word; 378 return std::nullopt; 379 } 380 381 SymbolRelevanceSignals::DerivedSignals 382 SymbolRelevanceSignals::calculateDerivedSignals() const { 383 DerivedSignals Derived; 384 Derived.NameMatchesContext = wordMatching(Name, ContextWords).has_value(); 385 Derived.FileProximityDistance = !FileProximityMatch || SymbolURI.empty() 386 ? FileDistance::Unreachable 387 : FileProximityMatch->distance(SymbolURI); 388 if (ScopeProximityMatch) { 389 // For global symbol, the distance is 0. 390 Derived.ScopeProximityDistance = 391 SymbolScope ? ScopeProximityMatch->distance(*SymbolScope) : 0; 392 } 393 return Derived; 394 } 395 396 float SymbolRelevanceSignals::evaluateHeuristics() const { 397 DerivedSignals Derived = calculateDerivedSignals(); 398 float Score = 1; 399 400 if (Forbidden) 401 return 0; 402 403 Score *= NameMatch; 404 405 // File proximity scores are [0,1] and we translate them into a multiplier in 406 // the range from 1 to 3. 407 Score *= 1 + 2 * std::max(fileProximityScore(Derived.FileProximityDistance), 408 SemaFileProximityScore); 409 410 if (ScopeProximityMatch) 411 // Use a constant scope boost for sema results, as scopes of sema results 412 // can be tricky (e.g. class/function scope). Set to the max boost as we 413 // don't load top-level symbols from the preamble and sema results are 414 // always in the accessible scope. 415 Score *= SemaSaysInScope 416 ? 2.0 417 : scopeProximityScore(Derived.ScopeProximityDistance); 418 419 if (Derived.NameMatchesContext) 420 Score *= 1.5; 421 422 // Symbols like local variables may only be referenced within their scope. 423 // Conversely if we're in that scope, it's likely we'll reference them. 424 if (Query == CodeComplete) { 425 // The narrower the scope where a symbol is visible, the more likely it is 426 // to be relevant when it is available. 427 switch (Scope) { 428 case GlobalScope: 429 break; 430 case FileScope: 431 Score *= 1.5f; 432 break; 433 case ClassScope: 434 Score *= 2; 435 break; 436 case FunctionScope: 437 Score *= 4; 438 break; 439 } 440 } else { 441 // For non-completion queries, the wider the scope where a symbol is 442 // visible, the more likely it is to be relevant. 443 switch (Scope) { 444 case GlobalScope: 445 break; 446 case FileScope: 447 Score *= 0.5f; 448 break; 449 default: 450 // TODO: Handle other scopes as we start to use them for index results. 451 break; 452 } 453 } 454 455 if (TypeMatchesPreferred) 456 Score *= 5.0; 457 458 // Penalize non-instance members when they are accessed via a class instance. 459 if (!IsInstanceMember && 460 (Context == CodeCompletionContext::CCC_DotMemberAccess || 461 Context == CodeCompletionContext::CCC_ArrowMemberAccess)) { 462 Score *= 0.2f; 463 } 464 465 if (InBaseClass) 466 Score *= 0.5f; 467 468 // Penalize for FixIts. 469 if (NeedsFixIts) 470 Score *= 0.5f; 471 472 // Use a sigmoid style boosting function similar to `References`, which flats 473 // out nicely for large values. This avoids a sharp gradient for heavily 474 // referenced symbols. Use smaller gradient for ScopeRefsInFile since ideally 475 // MainFileRefs <= ScopeRefsInFile. 476 if (MainFileRefs >= 2) { 477 // E.g.: (2, 1.12), (9, 2.0), (48, 3.0). 478 float S = std::pow(MainFileRefs, -0.11); 479 Score *= 11.0 * (1 - S) / (1 + S) + 0.7; 480 } 481 if (ScopeRefsInFile >= 2) { 482 // E.g.: (2, 1.04), (14, 2.0), (109, 3.0), (400, 3.6). 483 float S = std::pow(ScopeRefsInFile, -0.10); 484 Score *= 10.0 * (1 - S) / (1 + S) + 0.7; 485 } 486 487 return Score; 488 } 489 490 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, 491 const SymbolRelevanceSignals &S) { 492 OS << llvm::formatv("=== Symbol relevance: {0}\n", S.evaluateHeuristics()); 493 OS << llvm::formatv("\tName: {0}\n", S.Name); 494 OS << llvm::formatv("\tName match: {0}\n", S.NameMatch); 495 if (S.ContextWords) 496 OS << llvm::formatv( 497 "\tMatching context word: {0}\n", 498 wordMatching(S.Name, S.ContextWords).value_or("<none>")); 499 OS << llvm::formatv("\tForbidden: {0}\n", S.Forbidden); 500 OS << llvm::formatv("\tNeedsFixIts: {0}\n", S.NeedsFixIts); 501 OS << llvm::formatv("\tIsInstanceMember: {0}\n", S.IsInstanceMember); 502 OS << llvm::formatv("\tInBaseClass: {0}\n", S.InBaseClass); 503 OS << llvm::formatv("\tContext: {0}\n", getCompletionKindString(S.Context)); 504 OS << llvm::formatv("\tQuery type: {0}\n", static_cast<int>(S.Query)); 505 OS << llvm::formatv("\tScope: {0}\n", static_cast<int>(S.Scope)); 506 507 OS << llvm::formatv("\tSymbol URI: {0}\n", S.SymbolURI); 508 OS << llvm::formatv("\tSymbol scope: {0}\n", 509 S.SymbolScope ? *S.SymbolScope : "<None>"); 510 511 SymbolRelevanceSignals::DerivedSignals Derived = S.calculateDerivedSignals(); 512 if (S.FileProximityMatch) { 513 unsigned Score = fileProximityScore(Derived.FileProximityDistance); 514 OS << llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n", Score, 515 Derived.FileProximityDistance); 516 } 517 OS << llvm::formatv("\tSema file proximity: {0}\n", S.SemaFileProximityScore); 518 519 OS << llvm::formatv("\tSema says in scope: {0}\n", S.SemaSaysInScope); 520 if (S.ScopeProximityMatch) 521 OS << llvm::formatv("\tIndex scope boost: {0}\n", 522 scopeProximityScore(Derived.ScopeProximityDistance)); 523 524 OS << llvm::formatv( 525 "\tType matched preferred: {0} (Context type: {1}, Symbol type: {2}\n", 526 S.TypeMatchesPreferred, S.HadContextType, S.HadSymbolType); 527 528 return OS; 529 } 530 531 float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance) { 532 return SymbolQuality * SymbolRelevance; 533 } 534 535 // Produces an integer that sorts in the same order as F. 536 // That is: a < b <==> encodeFloat(a) < encodeFloat(b). 537 static uint32_t encodeFloat(float F) { 538 static_assert(std::numeric_limits<float>::is_iec559); 539 constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1); 540 541 // Get the bits of the float. Endianness is the same as for integers. 542 uint32_t U = llvm::bit_cast<uint32_t>(F); 543 // IEEE 754 floats compare like sign-magnitude integers. 544 if (U & TopBit) // Negative float. 545 return 0 - U; // Map onto the low half of integers, order reversed. 546 return U + TopBit; // Positive floats map onto the high half of integers. 547 } 548 549 std::string sortText(float Score, llvm::StringRef Name) { 550 // We convert -Score to an integer, and hex-encode for readability. 551 // Example: [0.5, "foo"] -> "41000000foo" 552 std::string S; 553 llvm::raw_string_ostream OS(S); 554 llvm::write_hex(OS, encodeFloat(-Score), llvm::HexPrintStyle::Lower, 555 /*Width=*/2 * sizeof(Score)); 556 OS << Name; 557 return S; 558 } 559 560 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, 561 const SignatureQualitySignals &S) { 562 OS << llvm::formatv("=== Signature Quality:\n"); 563 OS << llvm::formatv("\tNumber of parameters: {0}\n", S.NumberOfParameters); 564 OS << llvm::formatv("\tNumber of optional parameters: {0}\n", 565 S.NumberOfOptionalParameters); 566 OS << llvm::formatv("\tKind: {0}\n", S.Kind); 567 return OS; 568 } 569 570 } // namespace clangd 571 } // namespace clang 572