xref: /llvm-project/clang-tools-extra/clangd/DecisionForest.cpp (revision edd5d777e981ab6a4952c14c35f3ead330c4a761)
1 //===--- DecisionForest.cpp --------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Features.inc"
10 
11 #if !CLANGD_DECISION_FOREST
12 #include "Quality.h"
13 #include <cstdlib>
14 
15 namespace clang {
16 namespace clangd {
17 DecisionForestScores
evaluateDecisionForest(const SymbolQualitySignals & Quality,const SymbolRelevanceSignals & Relevance,float Base)18 evaluateDecisionForest(const SymbolQualitySignals &Quality,
19                        const SymbolRelevanceSignals &Relevance, float Base) {
20   llvm::errs() << "Clangd was compiled without decision forest support.\n";
21   std::abort();
22 }
23 
24 } // namespace clangd
25 } // namespace clang
26 
27 #else // !CLANGD_DECISION_FOREST
28 
29 #include "CompletionModel.h"
30 #include "Quality.h"
31 #include <cmath>
32 
33 namespace clang {
34 namespace clangd {
35 
36 DecisionForestScores
evaluateDecisionForest(const SymbolQualitySignals & Quality,const SymbolRelevanceSignals & Relevance,float Base)37 evaluateDecisionForest(const SymbolQualitySignals &Quality,
38                        const SymbolRelevanceSignals &Relevance, float Base) {
39   Example E;
40   E.setIsDeprecated(Quality.Deprecated);
41   E.setIsReservedName(Quality.ReservedName);
42   E.setIsImplementationDetail(Quality.ImplementationDetail);
43   E.setNumReferences(Quality.References);
44   E.setSymbolCategory(Quality.Category);
45 
46   SymbolRelevanceSignals::DerivedSignals Derived =
47       Relevance.calculateDerivedSignals();
48   int NumMatch = 0;
49   if (Relevance.ContextWords) {
50     for (const auto &Word : Relevance.ContextWords->keys()) {
51       if (Relevance.Name.contains_insensitive(Word)) {
52         ++NumMatch;
53       }
54     }
55   }
56   E.setIsNameInContext(NumMatch > 0);
57   E.setNumNameInContext(NumMatch);
58   E.setFractionNameInContext(
59       Relevance.ContextWords && !Relevance.ContextWords->empty()
60           ? NumMatch * 1.0 / Relevance.ContextWords->size()
61           : 0);
62   E.setIsInBaseClass(Relevance.InBaseClass);
63   E.setFileProximityDistanceCost(Derived.FileProximityDistance);
64   E.setSemaFileProximityScore(Relevance.SemaFileProximityScore);
65   E.setSymbolScopeDistanceCost(Derived.ScopeProximityDistance);
66   E.setSemaSaysInScope(Relevance.SemaSaysInScope);
67   E.setScope(Relevance.Scope);
68   E.setContextKind(Relevance.Context);
69   E.setIsInstanceMember(Relevance.IsInstanceMember);
70   E.setHadContextType(Relevance.HadContextType);
71   E.setHadSymbolType(Relevance.HadSymbolType);
72   E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);
73 
74   DecisionForestScores Scores;
75   // Exponentiating DecisionForest prediction makes the score of each tree a
76   // multiplciative boost (like NameMatch). This allows us to weigh the
77   // prediction score and NameMatch appropriately.
78   Scores.ExcludingName = pow(Base, Evaluate(E));
79   // Following cases are not part of the generated training dataset:
80   //  - Symbols with `NeedsFixIts`.
81   //  - Forbidden symbols.
82   //  - Keywords: Dataset contains only macros and decls.
83   if (Relevance.NeedsFixIts)
84     Scores.ExcludingName *= 0.5;
85   if (Relevance.Forbidden)
86     Scores.ExcludingName *= 0;
87   if (Quality.Category == SymbolQualitySignals::Keyword)
88     Scores.ExcludingName *= 4;
89 
90   // NameMatch should be a multiplier on total score to support rescoring.
91   Scores.Total = Relevance.NameMatch * Scores.ExcludingName;
92   return Scores;
93 }
94 
95 } // namespace clangd
96 } // namespace clang
97 
98 #endif // !CLANGD_DECISION_FOREST
99