"""Generate a mock model for LLVM tests.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters. By default the mock model always outputs
1; if "never" is passed as the second command-line argument it always outputs
0 instead.
"""

import os
import importlib.util
import sys

import tensorflow as tf

POLICY_DECISION_LABEL = "inlining_decision"
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "inlining_decision",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "int64_t",
            "shape": [
                1
            ]
        }
    }
]
"""


# pylint: disable=g-complex-comprehension
def get_input_signature():
    """Returns the list of features for LLVM inlining.

    Every feature is a scalar (shape ``()``) ``tf.TensorSpec`` named after the
    feature. The list holds the int64 features first, then the float32
    features, then the int32 ``step_type``.
    """
    # int64 features
    inputs = [
        tf.TensorSpec(dtype=tf.int64, shape=(), name=key)
        for key in [
            "caller_basic_block_count",
            "caller_conditionally_executed_blocks",
            "caller_users",
            "callee_basic_block_count",
            "callee_conditionally_executed_blocks",
            "callee_users",
            "nr_ctant_params",
            "node_count",
            "edge_count",
            "callsite_height",
            "cost_estimate",
            "sroa_savings",
            "sroa_losses",
            "load_elimination",
            "call_penalty",
            "call_argument_setup",
            "load_relative_intrinsic",
            "lowered_call_arg_setup",
            "indirect_call_penalty",
            "jump_table_penalty",
            "case_cluster_penalty",
            "switch_penalty",
            "unsimplified_common_instructions",
            "num_loops",
            "dead_blocks",
            "simplified_instructions",
            "constant_args",
            "constant_offset_ptr_args",
            "callsite_cost",
            "cold_cc_penalty",
            "last_call_to_static_bonus",
            "is_multiple_blocks",
            "nested_inlines",
            "nested_inline_cost_estimate",
            "threshold",
            "is_callee_avail_external",
            "is_caller_avail_external",
        ]
    ]

    # float32 features
    inputs.extend(
        [
            tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
            for key in ["discount", "reward"]
        ]
    )

    # int32 features
    inputs.extend(
        [tf.TensorSpec(dtype=tf.int32, shape=(), name=key) for key in ["step_type"]]
    )
    return inputs


def get_output_signature():
    """Returns the name of the model's single output (the decision label)."""
    return POLICY_DECISION_LABEL


def get_output_spec():
    """Returns the JSON text describing the model's output tensor."""
    return POLICY_OUTPUT_SPEC


def get_output_spec_path(path):
    """Returns the location of ``output_spec.json`` inside model dir ``path``."""
    return os.path.join(path, "output_spec.json")


def build_mock_model(path, signature, advice):
    """Build and save the mock model with the given signature.

    Args:
      path: directory to save the SavedModel into; ``output_spec.json`` is
        written alongside it.
      signature: dict with keys "inputs" (list of ``tf.TensorSpec``), "output"
        (name of the output entry) and "output_spec" (JSON text), as returned
        by get_signature().
      advice: the constant value (emitted as int64) returned on every call.
    """
    module = tf.Module()

    def action(*inputs):
        # The inputs are intentionally ignored: the mock returns `advice`
        # no matter which feature values it is called with.
        return {signature["output"]: tf.constant(value=advice, dtype=tf.int64)}

    module.action = tf.function()(action)
    # Use a distinct name rather than rebinding `action`, which above is the
    # traced Python function.
    signatures = {
        "action": module.action.get_concrete_function(signature["inputs"])
    }
    tf.saved_model.save(module, path, signatures=signatures)

    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(signature["output_spec"])


def get_signature():
    """Returns the inputs, output name and output-spec JSON for the model."""
    return {
        "inputs": get_input_signature(),
        "output": get_output_signature(),
        "output_spec": get_output_spec(),
    }


def main(argv):
    """Command-line entry point: ``<script> MODEL_PATH [never]``.

    Writes a mock model to MODEL_PATH. The model always returns 1, or 0 when
    the optional second argument is the literal string "never".

    Raises:
      SystemExit: with a usage message when the arguments are malformed.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (len(argv) == 2 or (len(argv) == 3 and argv[2] == "never")):
        prog = argv[0] if argv else "generate_mock_model"
        sys.exit(f"usage: {prog} <model_path> [never]")
    model_path = argv[1]

    print(f"Output model to: [{model_path}]")

    # Argument validation above guarantees a third argument can only be "never".
    constant_advice = 0 if len(argv) == 3 else 1
    print(f"The model will always return: {constant_advice}")

    signature = get_signature()
    build_mock_model(model_path, signature, constant_advice)


if __name__ == "__main__":
    main(sys.argv)