# xref: /llvm-project/clang-tools-extra/clangd/quality/CompletionModelCodegen.py (revision dd3c26a045c081620375a878159f536758baba6e)
"""Code generator for Code Completion Model Inference.

Tool runs on the Decision Forest model defined in {model} directory.
It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp
The generated files define the Example class named {cpp_class} having all the features as class members.
The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate.
"""
8985deba9SUtkarsh Saxena
import argparse
import json
import re
import struct
12985deba9SUtkarsh Saxena
13985deba9SUtkarsh Saxena
class CppClass:
    """Holds class name and names of the enclosing namespaces.

    Built from a (possibly namespace-qualified) C++ name such as
    ``ns1::ns2::Example``.
    """

    def __init__(self, cpp_class):
        """Splits `cpp_class` on "::" into namespaces and the class name.

        Empty namespace components (for example the one produced by a
        leading "::") are dropped.

        Raises:
            ValueError: if the component after the last "::" is empty.
        """
        parts = cpp_class.split("::")
        # Enclosing namespaces, outermost first.
        self.ns = [ns for ns in parts[:-1] if ns]
        # Unqualified class name.
        self.name = parts[-1]
        if not self.name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns snippet for opening namespace declarations.

        One `namespace X {` line per namespace, outermost first; the empty
        string when there are no namespaces.
        """
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns snippet for closing namespace declarations.

        One `} // namespace X` line per namespace, innermost first, so that
        it mirrors the output of `ns_begin`.
        """
        return "\n".join("} // namespace %s" % ns for ns in reversed(self.ns))
33985deba9SUtkarsh Saxena
34985deba9SUtkarsh Saxena
def header_guard(filename):
    """Returns the header guard macro name for the generated header.

    The guard is `GENERATED_DECISION_FOREST_MODEL_<FILENAME>_H` with the
    filename upper-cased. Any character that is not an ASCII letter, digit
    or underscore (e.g. "-" or ".") is replaced with "_" so that the result
    is always usable as a preprocessor macro name.
    """
    sanitized = re.sub(r"[^0-9A-Za-z_]", "_", filename)
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % sanitized.upper()
38985deba9SUtkarsh Saxena
39985deba9SUtkarsh Saxena
def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    The snippet is a labelled `return` of the node's score as a C++ float
    literal. `next_label` is accepted only so that every node generator
    shares one signature; it is not used here.
    """
    # JSON may store a whole-number score as an integer (e.g. 1); formatting
    # that directly would emit `1f`, which is not a valid C++ literal.
    # Going through float() guarantees a decimal point or exponent.
    return "%s: return %sf;" % (label, float(n["score"]))
43985deba9SUtkarsh Saxena
44985deba9SUtkarsh Saxena
def if_greater_node(n, label, next_label):
    """Returns code snippet for a if_greater node.

    Jumps to `next_label` if the Example feature (NUMBER) is greater than or
    equal to the threshold (the emitted comparison is `>=`).
    Comparing integers is much faster than comparing floats. Assuming floating
    points are represented as IEEE 754, the threshold is order-encoded to an
    integer here and compared against the value returned by the feature's
    getter. The original threshold is emitted next to it as a C++ comment.
    Control falls through if condition is evaluated to false.
    """
    threshold = n["threshold"]
    return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % (
        label,
        n["feature"],
        order_encode(threshold),
        threshold,
        next_label,
    )
59985deba9SUtkarsh Saxena
60985deba9SUtkarsh Saxena
def if_member_node(n, label, next_label):
    """Returns code snippet for a if_member node.

    Jumps to `next_label` if the Example feature (ENUM) is present in the set
    of enum values described in the node. The set is emitted as an OR of
    `BIT(<feature>_type::<member>)` terms and tested with a bitwise AND
    against the feature's getter.
    Control falls through if condition is evaluated to false.
    """
    members = "|".join(
        "BIT(%s_type::%s)" % (n["feature"], member) for member in n["set"]
    )
    # An empty set would otherwise emit `& ()`, which does not compile; a zero
    # mask keeps the generated code valid and the condition always false.
    if not members:
        members = "0"
    return "%s: if (E.get%s() & (%s)) goto %s;" % (
        label,
        n["feature"],
        members,
        next_label,
    )
75985deba9SUtkarsh Saxena
76985deba9SUtkarsh Saxena
def node(n, label, next_label):
    """Returns code snippet for the node.

    Dispatches on `n["operation"]` to the matching generator; all generators
    share the `(n, label, next_label)` signature.

    Raises:
        KeyError: if `n["operation"]` is not "boost", "if_greater" or
            "if_member".
    """
    generators = {
        "boost": boost_node,
        "if_greater": if_greater_node,
        "if_member": if_member_node,
    }
    generate = generators[n["operation"]]
    return generate(n, label, next_label)
84985deba9SUtkarsh Saxena
85985deba9SUtkarsh Saxena
def tree(t, tree_num, node_num):
    """Returns code for inferencing a Decision Tree.
    Also returns the size of the decision tree.

    The result is a `(lines, size)` pair: `lines` is a list with one code
    snippet per node and `size` is the number of nodes in the (sub)tree.

    A node of the tree starts with label `t{tree#}_n{node#}`.

    The tree contains two types of node: Conditional node and Leaf node.
    -   Conditional node evaluates a condition. If true, it jumps to the true
        node/child. Code is generated using pre-order traversal of the tree
        considering false node (`t["else"]`) as the first child. Therefore the
        false node is always the immediately next label.
    -   Leaf node (operation "boost") is emitted through `node`; it is the
        base case of the recursion and has size 1.
    """
    label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        leaf = node(t, label=label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    # The false branch directly follows this node in pre-order.
    false_code, false_size = tree(t["else"], tree_num=tree_num, node_num=node_num + 1)

    # The true branch starts after this node and the whole false subtree.
    true_node_num = node_num + false_size + 1
    true_label = "t%d_n%d" % (tree_num, true_node_num)
    true_code, true_size = tree(t["then"], tree_num=tree_num, node_num=true_node_num)

    condition = node(t, label=label, next_label=true_label)
    return [condition] + false_code + true_code, 1 + false_size + true_size
117985deba9SUtkarsh Saxena
118985deba9SUtkarsh Saxena
def gen_header_code(features_json, cpp_class, filename):
    """Returns code for header declaring the inference runtime.

    Declares the Example class named {cpp_class} inside relevant namespaces.
    The Example class contains all the features as class members. This
    class can be used to represent a code completion candidate.
    Provides `float Evaluate()` function which can be used to score the Example.

    Each entry of `features_json` is read for its "name" and "kind" keys.
    NUMBER features are stored as order-encoded 32-bit integers; ENUM features
    are stored as a 64-bit mask with a single bit set.

    Raises:
        ValueError: if a feature's "kind" is neither "NUMBER" nor "ENUM".
    """
    setters = []
    for f in features_json:
        feature = f["name"]

        if f["kind"] == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            setters.append(
                "void set%s(float V) { %s = OrderEncode(V); }" % (feature, feature)
            )
        elif f["kind"] == "ENUM":
            setters.append(
                "void set%s(unsigned V) { %s = 1LL << V; }" % (feature, feature)
            )
        else:
            raise ValueError("Unhandled feature type.", f["kind"])

    # Class members represent all the features of the Example.
    class_members = [
        "uint%d_t %s = 0;" % (64 if f["kind"] == "ENUM" else 32, f["name"])
        for f in features_json
    ]
    getters = [
        "LLVM_ATTRIBUTE_ALWAYS_INLINE uint%d_t get%s() const { return %s; }"
        % (64 if f["kind"] == "ENUM" else 32, f["name"], f["name"])
        for f in features_json
    ]
    # Separator that keeps every generated declaration at class-body indent.
    nline = "\n  "
    guard = header_guard(filename)
    return """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
  // Setters.
  %s

  // Getters.
  %s

private:
  %s

  // Produces an integer that sorts in the same order as F.
  // That is: a < b <==> orderEncode(a) < orderEncode(b).
  static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
""" % (
        guard,
        guard,
        cpp_class.ns_begin(),
        cpp_class.name,
        nline.join(setters),
        nline.join(getters),
        nline.join(class_members),
        cpp_class.name,
        cpp_class.ns_end(),
        guard,
    )
193985deba9SUtkarsh Saxena
194985deba9SUtkarsh Saxena
def order_encode(v):
    """Returns an unsigned 32-bit integer that sorts in the same order as `v`.

    `v` is packed as a 32-bit float and its bit pattern reinterpreted as an
    unsigned integer. Negative floats map onto the low half of the integer
    range with their order reversed; non-negative floats map onto the top
    half. Both 0.0 and -0.0 encode to 1 << 31.
    """
    bits = struct.unpack("<I", struct.pack("<f", v))[0]
    top_bit = 1 << 31
    # IEEE 754 floats compare like sign-magnitude integers.
    if bits & top_bit:  # Negative float
        return (1 << 32) - bits  # low half of integers, order reversed.
    return top_bit + bits  # top half of integers
202985deba9SUtkarsh Saxena
203985deba9SUtkarsh Saxena
def evaluate_func(forest_json, cpp_class):
    """Generates evaluation functions for each tree and combines them in
    `float Evaluate(const {Example}&)` function. This function can be
    used to score an Example.

    One `EvaluateTree<N>` function per entry of `forest_json` is emitted
    inside an anonymous namespace, followed by `Evaluate`, which sums their
    results. Only `cpp_class.name` is read from `cpp_class`.
    """
    pieces = []

    # Generate evaluation function of each tree.
    pieces.append("namespace {\n")
    for tree_num, tree_json in enumerate(forest_json):
        pieces.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n"
            % (tree_num, cpp_class.name)
        )
        tree_lines, _ = tree(tree_json, tree_num=tree_num, node_num=0)
        pieces.append("  " + "\n  ".join(tree_lines) + "\n")
        pieces.append("}\n\n")
    pieces.append("} // namespace\n\n")

    # Combine the scores of all trees in the final function.
    # MSAN will timeout if these functions are inlined.
    pieces.append("float Evaluate(const %s& E) {\n" % cpp_class.name)
    pieces.append("  float Score = 0;\n")
    pieces.extend(
        "  Score += EvaluateTree%d(E);\n" % tree_num
        for tree_num in range(len(forest_json))
    )
    pieces.append("  return Score;\n")
    pieces.append("}\n")

    return "".join(pieces)
236985deba9SUtkarsh Saxena
237985deba9SUtkarsh Saxena
def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Generates code for the .cpp file.

    The output consists of the includes, a `BIT` helper macro, the opening
    namespaces, `using` declarations for ENUM feature types, the definition
    of `{cpp_class}::OrderEncode`, the evaluation functions and the closing
    namespaces.

    Every ENUM entry of `features_json` is read for its "header", "name" and
    "type" keys.
    """
    # Headers
    # Required by OrderEncode(float F).
    angled_include = ["#include <%s>" % h for h in ["cstring", "limits"]]

    # Include generated header.
    quoted_headers = {filename + ".h", "llvm/ADT/bit.h"}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"] for f in features_json if f["kind"] == "ENUM"}
    # Sorted so that the generated include order is deterministic.
    quoted_include = ['#include "%s"' % h for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join(
        "using %s_type = %s;" % (feature["name"], feature["type"])
        for feature in features_json
        if feature["kind"] == "ENUM"
    )
    # NOTE(review): the generated OrderEncode initialises U with
    # llvm::bit_cast and then overwrites it with std::memcpy; one of the two
    # looks redundant. The template text is left unchanged here.
    return """%s

%s

#define BIT(X) (1LL << X)

%s

%s

uint32_t %s::OrderEncode(float F) {
  static_assert(std::numeric_limits<float>::is_iec559, "");
  constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

  // Get the bits of the float. Endianness is the same as for integers.
  uint32_t U = llvm::bit_cast<uint32_t>(F);
  std::memcpy(&U, &F, sizeof(U));
  // IEEE 754 floats compare like sign-magnitude integers.
  if (U & TopBit)    // Negative float.
    return 0 - U;    // Map onto the low half of integers, order reversed.
  return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
""" % (
        "\n".join(angled_include),
        "\n".join(quoted_include),
        cpp_class.ns_begin(),
        using_decls,
        cpp_class.name,
        evaluate_func(forest_json, cpp_class),
        cpp_class.ns_end(),
    )
291985deba9SUtkarsh Saxena
292985deba9SUtkarsh Saxena
def main():
    """Command-line entry point.

    Reads `{model}/features.json` and `{model}/forest.json`, then writes
    `{output_dir}/{filename}.cpp` and `{output_dir}/{filename}.h`.
    """
    parser = argparse.ArgumentParser("DecisionForestCodegen")
    # All options are required: a missing one would otherwise surface as
    # None and fail later with an unrelated error.
    parser.add_argument("--filename", required=True, help="output file name.")
    parser.add_argument("--output_dir", required=True, help="output directory.")
    parser.add_argument("--model", required=True, help="path to model directory.")
    parser.add_argument(
        "--cpp_class",
        required=True,
        help="The name of the class (which may be a namespace-qualified) created in generated header.",
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = "%s/%s.h" % (output_dir, filename)
    cpp_file = "%s/%s.cpp" % (output_dir, filename)
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = "%s/forest.json" % ns.model
    features_file = "%s/features.json" % ns.model

    with open(features_file) as f:
        features_json = json.load(f)

    with open(model_file) as m:
        forest_json = json.load(m)

    # Generate both outputs before opening either file, so that a generation
    # error does not leave a truncated output file behind.
    cpp_code = gen_cpp_code(
        forest_json=forest_json,
        features_json=features_json,
        filename=filename,
        cpp_class=cpp_class,
    )
    header_code = gen_header_code(
        features_json=features_json, cpp_class=cpp_class, filename=filename
    )

    with open(cpp_file, "w+t") as output_cc:
        output_cc.write(cpp_code)

    with open(header_file, "w+t") as output_h:
        output_h.write(header_code)
335985deba9SUtkarsh Saxena
336985deba9SUtkarsh Saxena
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
339