"""Generate a mock model for LLVM tests for Register Allocation.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters.
"""
## By construction, the mock model outputs, as the "priority" decision, the sum
## of every per-live-interval feature and every context feature (each cast to
## float32), plus a constant zero-valued variable.

import os
import sys
import tensorflow as tf

POLICY_DECISION_LABEL = "priority"
POLICY_OUTPUT_SPEC = """
[
  {
    "logging_name": "priority",
    "tensor_spec": {
      "name": "StatefulPartitionedCall",
      "port": 0,
      "type": "float",
      "shape": [
        1
      ]
    }
  }
]
"""
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ["li_size", "stage"]
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ["weight"]
PER_LIVEINTERVAL_FEATURE_LIST = (
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST + PER_LIVEINTERVAL_INT64_FEATURE_LIST
)
CONTEXT_FEATURE_LIST = ("discount", "reward", "step_type")


def get_input_signature():
    """Return the input signature of the mock model.

    Returns:
      A dict mapping each feature name to a scalar (shape ``()``)
      ``tf.TensorSpec`` named after the feature:
        * per-live-interval int64 features  -> tf.int64
        * per-live-interval float32 features -> tf.float32
        * "discount", "reward"               -> tf.float32
        * "step_type"                        -> tf.int32
    """
    inputs = {
        key: tf.TensorSpec(dtype=tf.int64, shape=(), name=key)
        for key in PER_LIVEINTERVAL_INT64_FEATURE_LIST
    }
    inputs.update(
        {
            key: tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
            for key in PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST
        }
    )
    inputs.update(
        {
            key: tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
            for key in ["discount", "reward"]
        }
    )
    inputs.update(
        {
            key: tf.TensorSpec(dtype=tf.int32, shape=(), name=key)
            for key in ["step_type"]
        }
    )
    return inputs


def get_output_spec_path(path):
    """Return the path of the output-spec JSON file inside model dir ``path``."""
    return os.path.join(path, "output_spec.json")


def build_mock_model(path):
    """Build and save the mock model with the given signature.

    Saves a SavedModel to ``path`` with a single "action" signature (see
    ``get_input_signature``), then writes ``POLICY_OUTPUT_SPEC`` to
    ``get_output_spec_path(path)``.

    Args:
      path: Directory to write the SavedModel and output spec into.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to
    # correctly intake it.
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # The concrete function is traced with a single dict argument, so
        # inputs[0] maps feature name -> tensor.
        features = inputs[0]
        # Stack the per-live-interval features and sum across them (axis=0),
        # keeping the shape of an individual feature tensor.
        s1 = tf.reduce_sum(
            [
                tf.cast(features[key], tf.float32)
                for key in PER_LIVEINTERVAL_FEATURE_LIST
            ],
            axis=0,
        )
        # Reduce all context features to a single scalar.
        s2 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST]
        )
        s = s1 + s2
        # Adding the (zero) variable ties it into the graph so it is saved.
        result = s + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function(action)
    signatures = {
        "action": module.action.get_concrete_function(get_input_signature())
    }

    tf.saved_model.save(module, path, signatures=signatures)
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(POLICY_OUTPUT_SPEC)


def main(argv):
    """Entry point: ``argv`` must be ``[program, model_path]``."""
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(argv) != 2:
        sys.exit("expected exactly one argument: the output model directory")
    model_path = argv[1]
    build_mock_model(model_path)


if __name__ == "__main__":
    main(sys.argv)