xref: /llvm-project/clang-tools-extra/clangd/quality/CompletionModelCodegen.py (revision dd3c26a045c081620375a878159f536758baba6e)
1"""Code generator for Code Completion Model Inference.
2
3Tool runs on the Decision Forest model defined in {model} directory.
4It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp
The generated files define the Example class named {cpp_class} having all the features as class members.
6The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate.
7"""
8
9import argparse
10import json
11import struct
12
13
class CppClass:
    """Splits a (possibly namespace-qualified) C++ class name into its parts.

    After construction, `ns` holds the enclosing namespace names in the order
    they were written (empty components are dropped) and `name` holds the
    unqualified class name.
    """

    def __init__(self, cpp_class):
        # Everything before the last "::" is the namespace path; the final
        # component is the class name itself.
        *namespaces, class_name = cpp_class.split("::")
        # Empty components (e.g. from a leading "::") carry no namespace.
        self.ns = [part for part in namespaces if part]
        self.name = class_name
        if not class_name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns the namespace-opening lines, outermost namespace first."""
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns the namespace-closing lines, innermost namespace first."""
        return "\n".join("} // namespace %s" % ns for ns in self.ns[::-1])
33
34
def header_guard(filename):
    """Returns the header guard macro name for the generated header.

    The name is derived from the upper-cased {filename}. Every character that
    is not an ASCII letter or digit (for example '-' or '.') is replaced by
    '_' so that the result is always a valid preprocessor identifier.
    """
    upper = filename.upper()
    sanitized = "".join(
        ch if ("A" <= ch <= "Z" or "0" <= ch <= "9") else "_" for ch in upper
    )
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % sanitized
38
39
def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    The snippet is a statement labelled {label} that returns the node's
    "score" as a C++ float literal. {next_label} is unused; it is accepted so
    that this function has the same signature as the other node generators.
    """
    # Coerce to float so that an integral score in the model (e.g. 1) is
    # emitted as "1.0f"; "1f" is not a valid C++ literal.
    score = float(n["score"])
    return "%s: return %sf;" % (label, score)
43
44
def if_greater_node(n, label, next_label):
    """Returns code snippet for an if_greater node.

    The emitted statement is labelled {label} and jumps to {next_label} when
    the Example's feature (NUMBER) compares `>=` the node's threshold.
    Control falls through if the condition evaluates to false.

    Comparing integers is much faster than comparing floats. Assuming floating
    points are represented as IEEE 754, the threshold is order-encoded to an
    integer before being written into the comparison; the original threshold
    is kept next to it in a comment.
    """
    raw_threshold = n["threshold"]
    fields = (
        label,
        n["feature"],
        order_encode(raw_threshold),
        raw_threshold,
        next_label,
    )
    return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % fields
59
60
def if_member_node(n, label, next_label):
    """Returns code snippet for an if_member node.

    The emitted statement is labelled {label} and jumps to {next_label} when
    the Example's feature (ENUM) has a bit in common with the mask built from
    the enum values listed in the node's "set".
    Control falls through if the condition evaluates to false.
    """
    enum_values = n["set"]
    # One BIT(...) term per enum value, OR-ed together into a single mask.
    mask = "|".join(
        "BIT(%s_type::%s)" % (n["feature"], value) for value in enum_values
    )
    fields = (label, n["feature"], mask, next_label)
    return "%s: if (E.get%s() & (%s)) goto %s;" % fields
75
76
def node(n, label, next_label):
    """Returns code snippet for the node.

    The generator is selected by the node's "operation" field; an operation
    that is not listed below raises KeyError.
    """
    generators = {
        "boost": boost_node,
        "if_greater": if_greater_node,
        "if_member": if_member_node,
    }
    generate = generators[n["operation"]]
    return generate(n, label, next_label)
84
85
def tree(t, tree_num, node_num):
    """Returns the generated code for one decision (sub)tree and its size.

    The result is a pair `(code, size)`: `code` is a list of labelled
    statements, one per node, as produced by `node`, and `size` is the number
    of nodes in the subtree rooted at {t}.

    Every node is labelled `t{tree#}_n{node#}`. Nodes are numbered in
    pre-order starting from {node_num}, visiting the false ("else") child
    before the true ("then") child.

    The tree contains two types of node:
    -   Conditional node: its false child is always the statement immediately
        after it, and the label of its true child, which follows the whole
        false subtree, is passed to `node` as the jump target.
    -   Leaf ("boost") node: a subtree of size 1. `t{tree# + 1}` is passed to
        `node` as its next label.
    """
    own_label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        leaf = node(t, label=own_label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    # The false subtree takes the node numbers right after this node.
    else_code, else_size = tree(t["else"], tree_num=tree_num, node_num=node_num + 1)

    # The true subtree starts once the false subtree's numbers are used up.
    then_num = node_num + else_size + 1
    then_label = "t%d_n%d" % (tree_num, then_num)
    then_code, then_size = tree(t["then"], tree_num=tree_num, node_num=then_num)

    branch = node(t, label=own_label, next_label=then_label)
    return [branch] + else_code + then_code, 1 + else_size + then_size
117
118
def gen_header_code(features_json, cpp_class, filename):
    """Returns the contents of the header declaring the inference runtime.

    The header declares the Example class named {cpp_class} inside its
    namespaces, wrapped in an include guard derived from {filename}. For each
    feature in {features_json} the class gets one member, one setter and one
    getter. The header also declares `float Evaluate(const {cpp_class}&)`,
    which can be used to score an Example.

    Raises ValueError if a feature's kind is neither NUMBER nor ENUM.
    """

    def bit_width(f):
        # ENUM features are stored in 64 bits, everything else in 32 bits.
        return 64 if f["kind"] == "ENUM" else 32

    setters = []
    for f in features_json:
        name = f["name"]
        kind = f["kind"]
        if kind == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            setter = "void set%s(float V) { %s = OrderEncode(V); }" % (name, name)
        elif kind == "ENUM":
            setter = "void set%s(unsigned V) { %s = 1LL << V; }" % (name, name)
        else:
            raise ValueError("Unhandled feature type.", kind)
        setters.append(setter)

    # Class members represent all the features of the Example.
    class_members = []
    for f in features_json:
        class_members.append("uint%d_t %s = 0;" % (bit_width(f), f["name"]))

    getters = []
    for f in features_json:
        getters.append(
            "LLVM_ATTRIBUTE_ALWAYS_INLINE uint%d_t get%s() const { return %s; }"
            % (bit_width(f), f["name"], f["name"])
        )

    # Separator that keeps consecutive declarations indented inside the class.
    indent_sep = "\n  "
    guard = header_guard(filename)
    template = """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
  // Setters.
  %s

  // Getters.
  %s

private:
  %s

  // Produces an integer that sorts in the same order as F.
  // That is: a < b <==> orderEncode(a) < orderEncode(b).
  static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
"""
    return template % (
        guard,
        guard,
        cpp_class.ns_begin(),
        cpp_class.name,
        indent_sep.join(setters),
        indent_sep.join(getters),
        indent_sep.join(class_members),
        cpp_class.name,
        cpp_class.ns_end(),
        guard,
    )
193
194
def order_encode(v):
    """Returns an unsigned 32-bit integer that sorts in the same order as {v}.

    {v} is packed as a 32-bit float and its bit pattern is read back as an
    unsigned integer. Negative floats are mapped onto the low half of the
    integer range with their order reversed; non-negative floats are mapped
    onto the top half.
    """
    bits = int.from_bytes(struct.pack("<f", v), "little")
    sign_bit = 1 << 31
    # IEEE 754 floats compare like sign-magnitude integers.
    if (bits & sign_bit) == 0:
        return sign_bit + bits  # top half of integers
    return (1 << 32) - bits  # low half of integers, order reversed.
202
203
def evaluate_func(forest_json, cpp_class):
    """Returns the C++ code that scores an Example against the whole forest.

    One `EvaluateTree{N}` function is generated per tree of {forest_json},
    inside an anonymous namespace, followed by
    `float Evaluate(const {Example}&)`, which sums the results of all of them.
    """
    example = cpp_class.name
    pieces = ["namespace {\n"]

    # Generate evaluation function of each tree.
    for index, tree_json in enumerate(forest_json):
        pieces.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n"
            % (index, example)
        )
        statements = tree(tree_json, tree_num=index, node_num=0)[0]
        pieces.append("  " + "\n  ".join(statements) + "\n")
        pieces.append("}\n\n")
    pieces.append("} // namespace\n\n")

    # Combine the scores of all trees in the final function.
    # MSAN will timeout if these functions are inlined.
    pieces.append("float Evaluate(const %s& E) {\n" % example)
    pieces.append("  float Score = 0;\n")
    pieces.extend(
        "  Score += EvaluateTree%d(E);\n" % index
        for index in range(len(forest_json))
    )
    pieces.append("  return Score;\n")
    pieces.append("}\n")

    return "".join(pieces)
236
237
def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Returns the contents of the generated .cpp file.

    The file consists of the includes, the `BIT` macro, the opening
    namespaces of {cpp_class}, one `{name}_type` using-declaration per ENUM
    feature, the definition of `OrderEncode`, the evaluation functions for
    {forest_json}, and the closing namespaces.
    """
    # Standard headers required by OrderEncode(float F).
    std_includes = "\n".join(
        "#include <%s>" % header for header in ("cstring", "limits")
    )

    # The generated header, plus the header providing llvm::bit_cast.
    quoted_headers = {filename + ".h", "llvm/ADT/bit.h"}
    # Headers required by ENUM features used by the model.
    for f in features_json:
        if f["kind"] == "ENUM":
            quoted_headers.add(f["header"])
    # Sorting makes the include order independent of set iteration order.
    project_includes = "\n".join(
        '#include "%s"' % header for header in sorted(quoted_headers)
    )

    # using-decl for ENUM features.
    enum_aliases = [
        "using %s_type = %s;" % (feature["name"], feature["type"])
        for feature in features_json
        if feature["kind"] == "ENUM"
    ]
    using_decls = "\n".join(enum_aliases)

    template = """%s

%s

#define BIT(X) (1LL << X)

%s

%s

uint32_t %s::OrderEncode(float F) {
  static_assert(std::numeric_limits<float>::is_iec559, "");
  constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

  // Get the bits of the float. Endianness is the same as for integers.
  uint32_t U = llvm::bit_cast<uint32_t>(F);
  std::memcpy(&U, &F, sizeof(U));
  // IEEE 754 floats compare like sign-magnitude integers.
  if (U & TopBit)    // Negative float.
    return 0 - U;    // Map onto the low half of integers, order reversed.
  return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
"""
    return template % (
        std_includes,
        project_includes,
        cpp_class.ns_begin(),
        using_decls,
        cpp_class.name,
        evaluate_func(forest_json, cpp_class),
        cpp_class.ns_end(),
    )
291
292
def main():
    """Command-line entry point of the code generator.

    Reads `features.json` and `forest.json` from the --model directory and
    writes `{filename}.cpp` and `{filename}.h` into --output_dir. All four
    options are required; argparse reports a usage error if one is missing.
    """
    parser = argparse.ArgumentParser("DecisionForestCodegen")
    # Every option is needed to build the paths and the generated code, so a
    # missing one is rejected up front instead of failing later on a None.
    parser.add_argument("--filename", required=True, help="output file name.")
    parser.add_argument("--output_dir", required=True, help="output directory.")
    parser.add_argument("--model", required=True, help="path to model directory.")
    parser.add_argument(
        "--cpp_class",
        required=True,
        help="The name of the class (which may be a namespace-qualified) created in generated header.",
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = "%s/%s.h" % (output_dir, filename)
    cpp_file = "%s/%s.cpp" % (output_dir, filename)
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = "%s/forest.json" % ns.model
    features_file = "%s/features.json" % ns.model

    with open(features_file) as f:
        features_json = json.load(f)

    with open(model_file) as m:
        forest_json = json.load(m)

    # Generate both files in memory before opening any output: opening for
    # writing truncates the file, so generating afterwards would leave an
    # empty output behind whenever generation raises.
    cpp_code = gen_cpp_code(
        forest_json=forest_json,
        features_json=features_json,
        filename=filename,
        cpp_class=cpp_class,
    )
    header_code = gen_header_code(
        features_json=features_json, cpp_class=cpp_class, filename=filename
    )

    with open(cpp_file, "w") as output_cc:
        output_cc.write(cpp_code)

    with open(header_file, "w") as output_h:
        output_h.write(header_code)
335
336
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
339