xref: /llvm-project/llvm/lib/Analysis/models/interactive_host.py (revision 65b40f273f09a53f61a13ac6f4bb65ec4ac63d6e)
15fd51fcbSMircea Trofin"""Utility for testing InteractiveModelRunner.
25fd51fcbSMircea Trofin
35fd51fcbSMircea TrofinUse it from pass-specific tests by providing a main .py which calls this library's
45fd51fcbSMircea Trofin`run_interactive` with an appropriate callback to provide advice.
55fd51fcbSMircea Trofin
65fd51fcbSMircea TrofinFrom .ll tests, just call the above-mentioned main as a prefix to the opt/llc
75fd51fcbSMircea Trofininvocation (with the appropriate flags enabling the interactive mode)
85fd51fcbSMircea Trofin
95fd51fcbSMircea TrofinExamples:
105fd51fcbSMircea Trofintest/Transforms/Inline/ML/interactive-mode.ll
11*65b40f27SMatt Arsenaulttest/CodeGen/MLRegAlloc/interactive-mode.ll
125fd51fcbSMircea Trofin"""
135fd51fcbSMircea Trofin
145fd51fcbSMircea Trofinimport ctypes
155fd51fcbSMircea Trofinimport log_reader
165fd51fcbSMircea Trofinimport io
175fd51fcbSMircea Trofinimport math
185fd51fcbSMircea Trofinimport os
195fd51fcbSMircea Trofinimport subprocess
2079f7a5e0SMircea Trofinfrom typing import Callable, List, Union
215fd51fcbSMircea Trofin
225fd51fcbSMircea Trofin
def send(f: io.BufferedWriter, value: Union[int, float], spec: "log_reader.TensorSpec"):
    """Send the `value` - currently just a scalar - formatted as per `spec`.

    Args:
      f: the writable binary stream to send on (e.g. the pipe to the compiler).
      value: the scalar to send. It is converted with `int()`, so a float is
        truncated toward zero.
      spec: the spec describing the expected payload. Only `ctypes.c_int64`
        element types are supported, and since exactly one value is written,
        the product of `spec.shape` must be 1.

    The stream is flushed after writing so the receiver sees the value
    immediately.
    """
    # Only int64 payloads are supported for now.
    assert spec.element_type == ctypes.c_int64
    payload = bytes(ctypes.c_int64(int(value)))
    # The write must not live inside the assert: `python -O` strips asserts,
    # which would silently drop the side effect and leave the reader waiting.
    written = f.write(payload)
    assert written == ctypes.sizeof(spec.element_type) * math.prod(spec.shape)
    f.flush()
335fd51fcbSMircea Trofin
345fd51fcbSMircea Trofin
def run_interactive(
    temp_rootname: str,
    make_response: Callable[[List["log_reader.TensorValue"]], Union[int, float]],
    process_and_args: List[str],
):
    """Host the compiler.

    Args:
      temp_rootname: the base file name from which to construct the 2 pipes for
        communicating with the compiler.
      make_response: a function that, given the current tensor values, provides a
        response.
      process_and_args: the full commandline for the compiler. It is assumed it
        contains a flag pointing to `temp_rootname` so that the
        InteractiveModelRunner would attempt communication on the same pair as
        this function opens.

    This function sets up the communication with the compiler - via 2 FIFOs named
    `temp_rootname`.in and `temp_rootname`.out - prints out the received features,
    and sends back to the compiler an advice (which it gets from `make_response`).
    Once the compiler has exited, its stderr is printed. It's used for testing,
    and also to showcase how to set up communication in an interactive ML ("gym")
    environment.

    Only the FIFOs created by this call are removed on exit. If a file already
    exists at either path, `os.mkfifo` raises `FileExistsError` and that file is
    left in place. If hosting fails part-way, a compiler process that is still
    running is killed and reaped before the exception propagates.
    """
    import tempfile

    to_compiler = temp_rootname + ".in"
    from_compiler = temp_rootname + ".out"
    created_fifos = []
    compiler_proc = None
    try:
        for fifo in (to_compiler, from_compiler):
            os.mkfifo(fifo, 0o666)
            # Record only after success so cleanup never removes a file we
            # did not create.
            created_fifos.append(fifo)
        # Spool stderr to a temporary file rather than a pipe: nothing reads a
        # pipe until the compiler exits, so a compiler writing more than the OS
        # pipe buffer to stderr would block mid-protocol and hang both sides.
        with tempfile.TemporaryFile() as err_file:
            compiler_proc = subprocess.Popen(
                process_and_args, stderr=err_file, stdout=subprocess.DEVNULL
            )
            # Opening a FIFO blocks until the peer opens the other end, so this
            # open order presumably has to match the compiler's; keep it as is.
            with io.BufferedWriter(io.FileIO(to_compiler, "wb")) as tc:
                with io.BufferedReader(io.FileIO(from_compiler, "rb")) as fc:
                    tensor_specs, _, advice_spec = log_reader.read_header(fc)
                    context = None
                    while compiler_proc.poll() is None:
                        next_event = fc.readline()
                        if not next_event:
                            # readline() returned b"" (EOF): no more events.
                            break
                        (
                            last_context,
                            observation_id,
                            features,
                            _,
                        ) = log_reader.read_one_observation(
                            context, next_event, fc, tensor_specs, None
                        )
                        if last_context != context:
                            print(f"context: {last_context}")
                        context = last_context
                        print(f"observation: {observation_id}")
                        tensor_values = []
                        for fv in features:
                            log_reader.pretty_print_tensor_value(fv)
                            tensor_values.append(fv)
                        send(tc, make_response(tensor_values), advice_spec)
            compiler_proc.wait()
            err_file.seek(0)
            print(err_file.read().decode("utf-8", errors="replace"))
    finally:
        # Reap the compiler if we are leaving early (e.g. make_response raised).
        if compiler_proc is not None and compiler_proc.poll() is None:
            compiler_proc.kill()
            compiler_proc.wait()
        for fifo in created_fifos:
            try:
                os.unlink(fifo)
            except FileNotFoundError:
                pass
96