"""Generate a mock model for LLVM tests for Register Allocation priority.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters. By construction, the mock model's
"priority" output is the sum of all of its input features (each cast to
float32) plus a variable that is initialized to 0.
"""

import os
import sys

import tensorflow as tf

POLICY_DECISION_LABEL = "priority"
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "priority",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "float",
            "shape": [
                1
            ]
        }
    }
]
"""
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ["li_size", "stage"]
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ["weight"]
PER_LIVEINTERVAL_FEATURE_LIST = (
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST + PER_LIVEINTERVAL_INT64_FEATURE_LIST
)
CONTEXT_FEATURE_LIST = ("discount", "reward", "step_type")


def get_input_signature():
    """Return the input signature of the mock priority model.

    Returns:
      A dict mapping each feature name to a scalar (shape ``()``)
      ``tf.TensorSpec`` named after that feature, with dtypes:
        * int64 for every key in PER_LIVEINTERVAL_INT64_FEATURE_LIST,
        * float32 for every key in PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST,
        * float32 for the "discount" and "reward" context features,
        * int32 for the "step_type" context feature.
    """
    inputs = {
        key: tf.TensorSpec(dtype=tf.int64, shape=(), name=key)
        for key in PER_LIVEINTERVAL_INT64_FEATURE_LIST
    }
    inputs.update(
        {
            key: tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
            for key in PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST
        }
    )
    inputs.update(
        {
            key: tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
            for key in ["discount", "reward"]
        }
    )
    inputs.update(
        {
            key: tf.TensorSpec(dtype=tf.int32, shape=(), name=key)
            for key in ["step_type"]
        }
    )
    return inputs


def get_output_spec_path(path):
    """Return the path of the JSON output spec stored next to the model."""
    return os.path.join(path, "output_spec.json")


def build_mock_model(path):
    """Build and save the mock model with the given signature.

    Args:
      path: Directory the TF SavedModel is written to. The JSON output spec
        (POLICY_OUTPUT_SPEC) is written into the same directory as
        ``output_spec.json``.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to correctly
    # intake it
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # The concrete function is traced with a single positional argument:
        # the feature dict produced by get_input_signature().
        features = inputs[0]
        s1 = tf.reduce_sum(
            [
                tf.cast(features[key], tf.float32)
                for key in PER_LIVEINTERVAL_FEATURE_LIST
            ],
            axis=0,
        )
        s2 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST]
        )
        # The "priority" is the sum of every input feature. module.var is
        # initialized to 0, so adding it does not change the value.
        s = s1 + s2
        result = s + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function()(action)
    action = {"action": module.action.get_concrete_function(get_input_signature())}

    tf.saved_model.save(module, path, signatures=action)
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(POLICY_OUTPUT_SPEC)


def main(argv):
    """Entry point: ``argv`` must be ``[program, model_path]``.

    Raises:
      SystemExit: if ``argv`` does not contain exactly one model path. An
        explicit check is used instead of ``assert`` so validation still
        happens under ``python -O``.
    """
    if len(argv) != 2:
        prog = argv[0] if argv else "gen-model"
        raise SystemExit(f"usage: {prog} MODEL_PATH")
    model_path = argv[1]
    build_mock_model(model_path)


if __name__ == "__main__":
    main(sys.argv)