xref: /llvm-project/llvm/lib/Analysis/models/gen-regalloc-priority-test-model.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
"""Generate a mock model for LLVM register-allocation priority tests.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters.
"""
## By construction, the mock model's "priority" output is the sum of all of its input features, cast to float32.
65b26f4f0SEric Wang
75b26f4f0SEric Wangimport os
85b26f4f0SEric Wangimport sys
95b26f4f0SEric Wangimport tensorflow as tf
10*b71edfaaSTobias Hieta
# Key of the single entry in the dict returned by the mock policy's action.
POLICY_DECISION_LABEL = "priority"

# JSON output spec written next to the saved model; it describes the one
# float output tensor ("StatefulPartitionedCall", port 0, shape [1]).
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "priority",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "float",
            "shape": [
                1
            ]
        }
    }
]
"""

# Per-live-interval input features, grouped by the dtype they are fed as.
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ["li_size", "stage"]
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ["weight"]
# All per-live-interval features: the float32 ones first, then the int64 ones.
PER_LIVEINTERVAL_FEATURE_LIST = [
    *PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST,
    *PER_LIVEINTERVAL_INT64_FEATURE_LIST,
]

# Context inputs that accompany every call to the policy.
CONTEXT_FEATURE_LIST = ("discount", "reward", "step_type")
335b26f4f0SEric Wang
345b26f4f0SEric Wang
def get_input_signature():
    """Return the input signature used to trace the mock policy's ``action``.

    Returns:
      A flat dict mapping each input feature name to a scalar (``shape=()``)
      ``tf.TensorSpec`` whose ``name`` equals the feature name:

        * ``PER_LIVEINTERVAL_INT64_FEATURE_LIST`` entries -> ``tf.int64``
        * ``PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST`` entries -> ``tf.float32``
        * ``"discount"`` and ``"reward"`` -> ``tf.float32``
        * ``"step_type"`` -> ``tf.int32``

      The last three keys are exactly the entries of ``CONTEXT_FEATURE_LIST``.
      They are listed separately here because they do not share one dtype.
    """

    def _scalar_specs(keys, dtype):
        # One scalar spec per key; the spec name mirrors the feature name.
        return {key: tf.TensorSpec(dtype=dtype, shape=(), name=key) for key in keys}

    inputs = _scalar_specs(PER_LIVEINTERVAL_INT64_FEATURE_LIST, tf.int64)
    inputs.update(_scalar_specs(PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST, tf.float32))
    inputs.update(_scalar_specs(["discount", "reward"], tf.float32))
    inputs.update(_scalar_specs(["step_type"], tf.int32))
    return inputs
605b26f4f0SEric Wang
615b26f4f0SEric Wang
def get_output_spec_path(path):
    """Return where the output-spec JSON lives inside the model directory *path*."""
    spec_file_name = "output_spec.json"
    spec_location = os.path.join(path, spec_file_name)
    return spec_location
645b26f4f0SEric Wang
655b26f4f0SEric Wang
def build_mock_model(path):
    """Build the mock priority model and save it to *path*.

    The saved ``tf.Module`` exposes one signature, ``"action"``. It is traced
    with the single dict argument produced by ``get_input_signature()`` and
    returns ``{POLICY_DECISION_LABEL: priority}``. ``priority`` is the float32
    sum of every feature in ``PER_LIVEINTERVAL_FEATURE_LIST`` and
    ``CONTEXT_FEATURE_LIST``, plus ``module.var``. Afterwards
    ``POLICY_OUTPUT_SPEC`` is written to ``get_output_spec_path(path)``.

    Args:
      path: Directory the SavedModel and its output spec are written to.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to
    # correctly intake it.
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # The function is traced with one positional argument: the feature
        # dict from get_input_signature().
        features = inputs[0]
        s1 = tf.reduce_sum(
            [
                tf.cast(features[key], tf.float32)
                for key in PER_LIVEINTERVAL_FEATURE_LIST
            ],
            axis=0,
        )
        s2 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST]
        )
        # No offset is added: the priority is simply the sum of all inputs.
        # module.var is created with value 0, so adding it keeps that sum.
        result = s1 + s2 + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function()(action)
    # Named `signatures` so it does not shadow the inner `action` function.
    signatures = {
        "action": module.action.get_concrete_function(get_input_signature())
    }

    tf.saved_model.save(module, path, signatures=signatures)
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, "w", encoding="utf-8") as f:
        print(f"Writing output spec to {output_spec_path}.")
        f.write(POLICY_OUTPUT_SPEC)
975b26f4f0SEric Wang
985b26f4f0SEric Wang
def main(argv):
    """Command-line entry point: build the mock model at the given path.

    Args:
      argv: The full command line, expected to be ``[program, model_path]``.

    Raises:
      SystemExit: With a usage message if ``argv`` does not hold exactly one
        argument after the program name.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and this also tells the user what was expected.
    if len(argv) != 2:
        prog = argv[0] if argv else "gen-regalloc-priority-test-model.py"
        raise SystemExit(f"usage: {prog} <model_path>")
    model_path = argv[1]
    build_mock_model(model_path)
1035b26f4f0SEric Wang
1045b26f4f0SEric Wang
if __name__ == "__main__":
    # Executed as a script (not imported): hand the raw command line,
    # program name included, to main().
    command_line = sys.argv
    main(command_line)
107