"""Code generator for Code Completion Model Inference.

The tool runs on the Decision Forest model defined in the {model} directory.
It generates two files: {output_dir}/{filename}.h and
{output_dir}/{filename}.cpp.
The generated files define the Example class named {cpp_class}, which has all
the features as class members.
The generated runtime provides an `Evaluate` function which can be used to
score a code completion candidate.
"""

import argparse
import json
import struct


class CppClass:
    """Holds class name and names of the enclosing namespaces."""

    def __init__(self, cpp_class):
        """Splits a (possibly namespace-qualified) C++ class name.

        Args:
            cpp_class: Name such as "ns1::ns2::Name". Empty namespace
                components (e.g. produced by a leading "::") are dropped.

        Raises:
            ValueError: If the component after the last "::" is empty.
        """
        *namespaces, name = cpp_class.split("::")
        if not name:
            raise ValueError("Empty class name.")
        self.ns = [ns for ns in namespaces if ns]
        self.name = name

    def ns_begin(self):
        """Returns snippet for opening namespace declarations."""
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns snippet for closing namespace declarations."""
        # Namespaces are closed innermost-first.
        return "\n".join("} // namespace %s" % ns for ns in reversed(self.ns))


def header_guard(filename):
    """Returns the header guard for the generated header."""
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % filename.upper()


def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    The snippet returns the node's score from the enclosing tree function.
    `next_label` is unused; it is accepted so that all node generators share
    one signature.
    """
    # JSON may encode a score as an integer (e.g. `1`). Formatting it directly
    # would emit `1f`, which is not a valid C++ literal, so convert to float
    # first (`1` -> `1.0f`). Float scores are formatted exactly as before.
    return "%s: return %sf;" % (label, float(n["score"]))


def if_greater_node(n, label, next_label):
    """Returns code snippet for an if_greater node.

    The snippet jumps to `next_label` if the Example feature (NUMBER) is
    greater than or equal to the threshold (the emitted comparison is `>=`).
    Comparing integers is much faster than comparing floats. Assuming floating
    points are represented as IEEE 754, the threshold is order-encoded to an
    integer before being emitted; the original value is kept in a comment.
    Control falls through if the condition evaluates to false.
    """
    threshold = n["threshold"]
    return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % (
        label,
        n["feature"],
        order_encode(threshold),
        threshold,
        next_label,
    )


def if_member_node(n, label, next_label):
    """Returns code snippet for an if_member node.

    The snippet jumps to `next_label` if the Example feature (ENUM) is present
    in the set of enum values described in the node.
    Control falls through if the condition evaluates to false.
    """
    feature = n["feature"]
    # ENUM features are stored as a one-hot bit, so membership is a bit test
    # against the OR of all accepted values.
    members = "|".join("BIT(%s_type::%s)" % (feature, member) for member in n["set"])
    return "%s: if (E.get%s() & (%s)) goto %s;" % (
        label,
        feature,
        members,
        next_label,
    )


def node(n, label, next_label):
    """Returns code snippet for the node.

    Dispatches on n["operation"]; an unknown operation raises KeyError.
    """
    generators = {
        "boost": boost_node,
        "if_greater": if_greater_node,
        "if_member": if_member_node,
    }
    return generators[n["operation"]](n, label, next_label)


def tree(t, tree_num, node_num):
    """Returns code for inferencing a Decision Tree.
    Also returns the size of the decision tree.

    A node of the tree starts with label `t{tree#}_n{node#}`.

    The tree contains two types of node: Conditional node and Leaf node.
    -   Conditional node evaluates a condition. If true, it jumps to the true
        node/child. Code is generated using pre-order traversal of the tree
        considering false node as the first child. Therefore the false node
        is always the immediately next label.
    -   Leaf node returns its score.

    Returns:
        A tuple (code, size): the list of generated statements (one per node)
        and the number of nodes in the subtree rooted at `t`.
    """
    label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        leaf = node(t, label=label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    # The false branch directly follows this node so that control can fall
    # through to it.
    false_code, false_size = tree(t["else"], tree_num=tree_num, node_num=node_num + 1)

    # The true branch is numbered after every node of the false branch.
    true_node_num = node_num + false_size + 1
    true_label = "t%d_n%d" % (tree_num, true_node_num)
    true_code, true_size = tree(t["then"], tree_num=tree_num, node_num=true_node_num)

    condition = node(t, label=label, next_label=true_label)
    return [condition] + false_code + true_code, 1 + false_size + true_size


def gen_header_code(features_json, cpp_class, filename):
    """Returns code for header declaring the inference runtime.

    Declares the Example class named {cpp_class} inside relevant namespaces.
    The Example class contains all the features as class members. This
    class can be used to represent a code completion candidate.
    Provides `float Evaluate()` function which can be used to score the Example.

    Raises:
        ValueError: If a feature's "kind" is neither "NUMBER" nor "ENUM".
    """
    setters = []
    getters = []
    # Class members represent all the features of the Example.
    class_members = []
    for f in features_json:
        feature = f["name"]
        kind = f["kind"]

        if kind == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            bits = 32
            setters.append(
                "void set%s(float V) { %s = OrderEncode(V); }" % (feature, feature)
            )
        elif kind == "ENUM":
            # Enum values are stored as a single set bit in a 64-bit mask.
            bits = 64
            setters.append(
                "void set%s(unsigned V) { %s = 1LL << V; }" % (feature, feature)
            )
        else:
            raise ValueError("Unhandled feature type.", kind)

        getters.append(
            "LLVM_ATTRIBUTE_ALWAYS_INLINE uint%d_t get%s() const { return %s; }"
            % (bits, feature, feature)
        )
        class_members.append("uint%d_t %s = 0;" % (bits, feature))

    nline = "\n  "
    guard = header_guard(filename)
    return """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
  // Setters.
  %s

  // Getters.
  %s

private:
  %s

  // Produces an integer that sorts in the same order as F.
  // That is: a < b <==> orderEncode(a) < orderEncode(b).
  static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
""" % (
        guard,
        guard,
        cpp_class.ns_begin(),
        cpp_class.name,
        nline.join(setters),
        nline.join(getters),
        nline.join(class_members),
        cpp_class.name,
        cpp_class.ns_end(),
        guard,
    )


def order_encode(v):
    """Returns an unsigned 32-bit integer that sorts in the same order as `v`.

    `v` is packed as a little-endian IEEE 754 single-precision float and its
    bit pattern is remapped so that integer comparison matches float
    comparison.
    """
    i = struct.unpack("<I", struct.pack("<f", v))[0]
    top_bit = 1 << 31
    # IEEE 754 floats compare like sign-magnitude integers.
    if i & top_bit:  # Negative float
        return (1 << 32) - i  # low half of integers, order reversed.
    return top_bit + i  # top half of integers


def evaluate_func(forest_json, cpp_class):
    """Generates evaluation functions for each tree and combines them in
    `float Evaluate(const {Example}&)` function. This function can be
    used to score an Example."""
    parts = []

    # Generate evaluation function of each tree.
    # MSAN will timeout if these functions are inlined.
    parts.append("namespace {\n")
    for tree_num, tree_json in enumerate(forest_json):
        statements, _ = tree(tree_json, tree_num=tree_num, node_num=0)
        parts.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n"
            % (tree_num, cpp_class.name)
        )
        parts.append("  " + "\n  ".join(statements) + "\n")
        parts.append("}\n\n")
    parts.append("} // namespace\n\n")

    # Combine the scores of all trees in the final function.
    parts.append("float Evaluate(const %s& E) {\n" % cpp_class.name)
    parts.append("  float Score = 0;\n")
    parts.extend(
        "  Score += EvaluateTree%d(E);\n" % tree_num
        for tree_num in range(len(forest_json))
    )
    parts.append("  return Score;\n")
    parts.append("}\n")

    return "".join(parts)


def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Generates code for the .cpp file."""
    # Headers
    # Required by OrderEncode(float F).
    angled_include = ["#include <%s>" % h for h in ["cstring", "limits"]]

    # Include generated header.
    quoted_headers = {filename + ".h", "llvm/ADT/bit.h"}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"] for f in features_json if f["kind"] == "ENUM"}
    # Sorted so that the generated file is deterministic.
    quoted_include = ['#include "%s"' % h for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join(
        "using %s_type = %s;" % (feature["name"], feature["type"])
        for feature in features_json
        if feature["kind"] == "ENUM"
    )
    nl = "\n"
    return """%s

%s

#define BIT(X) (1LL << X)

%s

%s

uint32_t %s::OrderEncode(float F) {
  static_assert(std::numeric_limits<float>::is_iec559, "");
  constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

  // Get the bits of the float. Endianness is the same as for integers.
  uint32_t U = llvm::bit_cast<uint32_t>(F);
  std::memcpy(&U, &F, sizeof(U));
  // IEEE 754 floats compare like sign-magnitude integers.
  if (U & TopBit)    // Negative float.
    return 0 - U;    // Map onto the low half of integers, order reversed.
  return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
""" % (
        nl.join(angled_include),
        nl.join(quoted_include),
        cpp_class.ns_begin(),
        using_decls,
        cpp_class.name,
        evaluate_func(forest_json, cpp_class),
        cpp_class.ns_end(),
    )


def main():
    """Parses the command line and writes the generated .h and .cpp files."""
    parser = argparse.ArgumentParser("DecisionForestCodegen")
    # Every argument is needed to produce output; requiring them makes argparse
    # report a missing one instead of failing later on a None value.
    parser.add_argument("--filename", required=True, help="output file name.")
    parser.add_argument("--output_dir", required=True, help="output directory.")
    parser.add_argument("--model", required=True, help="path to model directory.")
    parser.add_argument(
        "--cpp_class",
        required=True,
        help="The name of the class (which may be a namespace-qualified) created in generated header.",
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = "%s/%s.h" % (output_dir, filename)
    cpp_file = "%s/%s.cpp" % (output_dir, filename)
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = "%s/forest.json" % ns.model
    features_file = "%s/features.json" % ns.model

    with open(features_file, encoding="utf-8") as f:
        features_json = json.load(f)

    with open(model_file, encoding="utf-8") as m:
        forest_json = json.load(m)

    # Generate both sources before touching the outputs so that a malformed
    # model does not leave truncated files behind.
    cpp_code = gen_cpp_code(
        forest_json=forest_json,
        features_json=features_json,
        filename=filename,
        cpp_class=cpp_class,
    )
    header_code = gen_header_code(
        features_json=features_json, cpp_class=cpp_class, filename=filename
    )

    with open(cpp_file, "w", encoding="utf-8") as output_cc:
        output_cc.write(cpp_code)

    with open(header_file, "w", encoding="utf-8") as output_h:
        output_h.write(header_code)


if __name__ == "__main__":
    main()