xref: /llvm-project/llvm/lib/Analysis/models/gen-regalloc-priority-test-model.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
1"""Generate a mock model for LLVM tests for Register Allocation.
2The generated model is not a neural net - it is just a tf.function with the
3correct input and output parameters.
4"""
## By construction, the mock model's "priority" output is simply the sum of all
## its input features (each cast to float32).
6
7import os
8import sys
9import tensorflow as tf
10
# Name of the model's single output; build_mock_model uses it as the key of the
# dict returned by the traced tf.function.
POLICY_DECISION_LABEL = "priority"
# JSON output spec that build_mock_model writes next to the saved model (as
# output_spec.json). It declares one float tensor, logged as "priority", read
# from tensor "StatefulPartitionedCall" port 0, with shape [1].
# NOTE(review): every input spec in get_input_signature is a scalar (shape=()),
# so the traced output is presumably a scalar as well; confirm the consumer
# accepts the [1] shape declared here.
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "priority",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "float",
            "shape": [
                1
            ]
        }
    }
]
"""
# Per-live-interval features, grouped by the dtype of their input TensorSpecs
# in get_input_signature.
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ["li_size", "stage"]
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ["weight"]
# All per-live-interval features; the mock model sums these.
PER_LIVEINTERVAL_FEATURE_LIST = (
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST + PER_LIVEINTERVAL_INT64_FEATURE_LIST
)
# Context features, also summed by the mock model. Their input specs are
# declared individually in get_input_signature (float32 for discount/reward,
# int32 for step_type).
CONTEXT_FEATURE_LIST = ("discount", "reward", "step_type")
33
34
def get_input_signature():
    """Return the input signature of the mock register-allocation priority model.

    Returns:
      A dict mapping each input feature name to a scalar (``shape=()``)
      ``tf.TensorSpec`` carrying that same name. Keys appear in this order:
      the int64 per-live-interval features, the float32 per-live-interval
      features, the float32 ``discount``/``reward`` context features, and the
      int32 ``step_type`` context feature.
    """

    def _scalar_specs(keys, dtype):
        # One scalar spec per feature, named after the feature itself.
        return {key: tf.TensorSpec(dtype=dtype, shape=(), name=key) for key in keys}

    inputs = _scalar_specs(PER_LIVEINTERVAL_INT64_FEATURE_LIST, tf.int64)
    inputs.update(_scalar_specs(PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST, tf.float32))
    # Context features do not share a dtype, so they are declared explicitly.
    inputs.update(_scalar_specs(["discount", "reward"], tf.float32))
    inputs.update(_scalar_specs(["step_type"], tf.int32))
    return inputs
60
61
def get_output_spec_path(path):
    """Return where the output spec JSON lives inside model directory *path*."""
    spec_filename = "output_spec.json"
    return os.path.join(path, spec_filename)
64
65
def build_mock_model(path):
    """Build and save the mock priority model together with its output spec.

    The model casts every input feature to float32 and returns their sum as
    the ``POLICY_DECISION_LABEL`` output. It is saved with
    ``tf.saved_model.save`` under a single signature named "action", and
    ``POLICY_OUTPUT_SPEC`` is written to ``<path>/output_spec.json``.

    Args:
      path: Directory in which the saved model and output spec are written.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to
    # correctly intake it.
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # The concrete function is traced with one positional argument: the
        # feature dict described by get_input_signature().
        features = inputs[0]
        s1 = tf.reduce_sum(
            [
                tf.cast(features[key], tf.float32)
                for key in PER_LIVEINTERVAL_FEATURE_LIST
            ],
            axis=0,
        )
        s2 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST]
        )
        s = s1 + s2
        # Adding the (zero-valued) variable makes the traced function use it.
        result = s + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function(action)
    signatures = {
        "action": module.action.get_concrete_function(get_input_signature())
    }

    tf.saved_model.save(module, path, signatures=signatures)
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(POLICY_OUTPUT_SPEC)
97
98
def main(argv):
    """Build the mock model into the directory given as ``argv[1]``.

    Args:
      argv: Command-line arguments as in ``sys.argv``. Exactly one argument
        (the output model directory) is expected after the program name.

    Raises:
      SystemExit: with a usage message when the argument count is wrong.
    """
    # Validate explicitly rather than with `assert`, which `python -O` strips.
    if len(argv) != 2:
        prog = argv[0] if argv else "gen-regalloc-priority-test-model.py"
        sys.exit(f"usage: {prog} <model_path>")
    model_path = argv[1]
    build_mock_model(model_path)
103
104
# Script entry point: pass the full argv (program name plus model path) to main.
if __name__ == "__main__":
    main(sys.argv)
107