"""Generate a mock model for LLVM tests.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters. By construction, the mock model will always
output 1.
"""

import os
import importlib.util
import sys

import tensorflow as tf

POLICY_DECISION_LABEL = 'inlining_decision'
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "inlining_decision",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "int64_t",
            "shape": [
                1
            ]
        }
    }
]
"""


# pylint: disable=g-complex-comprehension
def get_input_signature():
  """Returns the list of features for LLVM inlining.

  Every feature is described by a scalar (``shape=()``) ``tf.TensorSpec``
  named after the feature.

  Returns:
    A list of ``tf.TensorSpec``: all int64 features first (in the order listed
    below), then the float32 features ``discount`` and ``reward``, then the
    int32 feature ``step_type``.
  """
  int64_feature_names = (
      'caller_basic_block_count',
      'caller_conditionally_executed_blocks',
      'caller_users',
      'callee_basic_block_count',
      'callee_conditionally_executed_blocks',
      'callee_users',
      'nr_ctant_params',
      'node_count',
      'edge_count',
      'callsite_height',
      'cost_estimate',
      'inlining_default',
      'sroa_savings',
      'sroa_losses',
      'load_elimination',
      'call_penalty',
      'call_argument_setup',
      'load_relative_intrinsic',
      'lowered_call_arg_setup',
      'indirect_call_penalty',
      'jump_table_penalty',
      'case_cluster_penalty',
      'switch_penalty',
      'unsimplified_common_instructions',
      'num_loops',
      'dead_blocks',
      'simplified_instructions',
      'constant_args',
      'constant_offset_ptr_args',
      'callsite_cost',
      'cold_cc_penalty',
      'last_call_to_static_bonus',
      'is_multiple_blocks',
      'nested_inlines',
      'nested_inline_cost_estimate',
      'threshold',
  )

  # Group order matters: it fixes the order of the returned specs.
  feature_groups = (
      (tf.int64, int64_feature_names),
      (tf.float32, ('discount', 'reward')),
      (tf.int32, ('step_type',)),
  )

  return [
      tf.TensorSpec(dtype=dtype, shape=(), name=feature_name)
      for dtype, feature_names in feature_groups
      for feature_name in feature_names
  ]


def get_output_signature():
  """Return the name of the mock model's single output."""
  return POLICY_DECISION_LABEL


def get_output_spec():
  """Return the JSON output-spec text written next to the saved model."""
  return POLICY_OUTPUT_SPEC


def get_output_spec_path(path):
  """Return where the output spec file lives inside the model directory.

  Args:
    path: Directory the saved model is written to.

  Returns:
    ``os.path.join(path, 'output_spec.json')``.
  """
  return os.path.join(path, 'output_spec.json')


def build_mock_model(path, signature):
  """Build and save the mock model with the given signature.

  The saved module exports one function under the 'action' signature. That
  function ignores its inputs and always returns the int64 constant 1, keyed
  by ``signature['output']``. The output spec text is then written to
  ``get_output_spec_path(path)``.

  Args:
    path: Directory to write the SavedModel and the output spec into.
    signature: Dict with keys 'inputs' (list of ``tf.TensorSpec``), 'output'
      (output name) and 'output_spec' (JSON text), as built by
      ``get_signature()``.
  """
  module = tf.Module()

  def action(*inputs):
    del inputs  # Unused: the mock model's output never depends on its inputs.
    return {signature['output']: tf.constant(value=1, dtype=tf.int64)}

  module.action = tf.function(action)
  # Trace with the declared input specs so the exported 'action' signature
  # carries them. The spec list is passed as one positional argument.
  signatures = {
      'action': module.action.get_concrete_function(signature['inputs'])
  }
  tf.saved_model.save(module, path, signatures=signatures)

  output_spec_path = get_output_spec_path(path)
  with open(output_spec_path, 'w', encoding='utf-8') as f:
    print(f'Writing output spec to {output_spec_path}.')
    f.write(signature['output_spec'])


def get_signature():
  """Collect the inputs, output name and output spec for the mock model."""
  return {
      'inputs': get_input_signature(),
      'output': get_output_signature(),
      'output_spec': get_output_spec()
  }


def main(argv):
  """Generate the mock model into the directory given on the command line.

  Args:
    argv: Command-line arguments; ``argv[1]`` is the output model directory.

  Raises:
    SystemExit: If ``argv`` does not hold exactly one argument after the
      program name.
  """
  # Validate explicitly instead of with `assert`, which `python -O` strips.
  if len(argv) != 2:
    prog = argv[0] if argv else 'gen-model'
    sys.exit(f'Usage: {prog} <model_path>')
  model_path = argv[1]

  print(f'Output model to: [{model_path}]')
  signature = get_signature()
  build_mock_model(model_path, signature)


if __name__ == '__main__':
  main(sys.argv)