xref: /llvm-project/llvm/lib/Analysis/models/gen-inline-oz-test-model.py (revision 600ff287722a15106e9a02c470b9865dda05980e)
1"""Generate a mock model for LLVM tests.
2
3The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters. By construction, the mock model always
outputs the same constant: 1 by default, or 0 when the script is invoked with
the extra "never" argument.
6"""
7
8import os
9import importlib.util
10import sys
11
12import tensorflow as tf
13
# Key under which the mock model's action returns its decision; it is also the
# "logging_name" used in POLICY_OUTPUT_SPEC below.
POLICY_DECISION_LABEL = "inlining_decision"
# JSON text written verbatim to the output_spec.json file placed next to the
# saved model (see build_mock_model).
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "inlining_decision",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "int64_t",
            "shape": [
                1
            ]
        }
    }
]
"""
30
31
32# pylint: disable=g-complex-comprehension
def get_input_signature():
    """Build the input feature specs of the mock LLVM inlining model.

    Returns:
        A list of scalar (shape ``()``) ``tf.TensorSpec`` objects: first the
        int64 features, then the float32 ones, then the int32 ones.
    """
    int64_feature_names = (
        "caller_basic_block_count",
        "caller_conditionally_executed_blocks",
        "caller_users",
        "callee_basic_block_count",
        "callee_conditionally_executed_blocks",
        "callee_users",
        "nr_ctant_params",
        "node_count",
        "edge_count",
        "callsite_height",
        "cost_estimate",
        "sroa_savings",
        "sroa_losses",
        "load_elimination",
        "call_penalty",
        "call_argument_setup",
        "load_relative_intrinsic",
        "lowered_call_arg_setup",
        "indirect_call_penalty",
        "jump_table_penalty",
        "case_cluster_penalty",
        "switch_penalty",
        "unsimplified_common_instructions",
        "num_loops",
        "dead_blocks",
        "simplified_instructions",
        "constant_args",
        "constant_offset_ptr_args",
        "callsite_cost",
        "cold_cc_penalty",
        "last_call_to_static_bonus",
        "is_multiple_blocks",
        "nested_inlines",
        "nested_inline_cost_estimate",
        "threshold",
        "is_callee_avail_external",
        "is_caller_avail_external",
    )
    float32_feature_names = ("discount", "reward")
    int32_feature_names = ("step_type",)

    def scalar_specs(names, dtype):
        # One scalar spec per feature name, preserving the given order.
        return [tf.TensorSpec(dtype=dtype, shape=(), name=name) for name in names]

    # The concatenation order defines the positional order of model inputs.
    return (
        scalar_specs(int64_feature_names, tf.int64)
        + scalar_specs(float32_feature_names, tf.float32)
        + scalar_specs(int32_feature_names, tf.int32)
    )
92
93
def get_output_signature():
    """Return the key under which the model reports its decision."""
    decision_label = POLICY_DECISION_LABEL
    return decision_label
96
97
def get_output_spec():
    """Return the JSON text describing the model's output tensor."""
    spec_json = POLICY_OUTPUT_SPEC
    return spec_json
100
101
def get_output_spec_path(path):
    """Return the location of the output spec file inside directory `path`."""
    spec_filename = "output_spec.json"
    return os.path.join(path, spec_filename)
104
105
def build_mock_model(path, signature, advice):
    """Save a constant-output mock model and its output spec under `path`.

    Args:
        path: Directory the SavedModel and the output spec file are written to.
        signature: Dict with "inputs", "output" and "output_spec" entries.
        advice: Constant the model's action returns, as an int64 tensor.
    """
    mock = tf.Module()

    @tf.function()
    def action(*inputs):
        # Every input is ignored; the advice is baked in as a constant.
        return {signature["output"]: tf.constant(value=advice, dtype=tf.int64)}

    mock.action = action
    serving_signatures = {
        "action": mock.action.get_concrete_function(signature["inputs"])
    }
    tf.saved_model.save(mock, path, signatures=serving_signatures)

    # The output spec is stored alongside the saved model.
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w") as spec_file:
        print(f"Writing output spec to {output_spec_path}.")
        spec_file.write(signature["output_spec"])
121
122
def get_signature():
    """Bundle the model's input specs, output key and output spec JSON.

    Returns:
        A dict with the keys "inputs", "output" and "output_spec".
    """
    signature = {}
    signature["inputs"] = get_input_signature()
    signature["output"] = get_output_signature()
    signature["output_spec"] = get_output_spec()
    return signature
129
130
def main(argv):
    """Generate the mock inlining model from command-line style arguments.

    Args:
        argv: Argument list shaped like ``sys.argv``. ``[prog, model_path]``
            builds a model that always returns 1; ``[prog, model_path,
            "never"]`` builds a model that always returns 0.

    Raises:
        ValueError: If `argv` matches neither of the two accepted forms.
    """
    # Validate explicitly rather than with `assert`: assertions are stripped
    # under `python -O`, which would let any third argument silently select
    # the always-0 model and let a missing path fail with a bare IndexError.
    never_inline = len(argv) == 3 and argv[2] == "never"
    if len(argv) != 2 and not never_inline:
        raise ValueError("expected arguments: <output_model_path> [never]")

    model_path = argv[1]
    print(f"Output model to: [{model_path}]")

    constant_advice = 0 if never_inline else 1
    print(f"The model will always return: {constant_advice}")

    signature = get_signature()
    build_mock_model(model_path, signature, constant_advice)
144
145
# When run as a script, forward the full command line (including argv[0]) to
# main, which expects a sys.argv-shaped list.
if __name__ == "__main__":
    main(sys.argv)
148