"""Utility for testing InteractiveModelRunner.

Use it from pass-specific tests by providing a main .py which calls this library's
`run_interactive` with an appropriate callback to provide advice.

From .ll tests, just call the above-mentioned main as a prefix to the opt/llc
invocation (with the appropriate flags enabling the interactive mode).

Examples:
test/Transforms/Inline/ML/interactive-mode.ll
test/CodeGen/MLRegAlloc/interactive-mode.ll
"""

import ctypes
import io
import math
import os
import subprocess
from typing import Callable, List, Union

import log_reader


def send(
    f: io.BufferedWriter, value: Union[int, float], spec: log_reader.TensorSpec
) -> None:
    """Send the `value` - currently just a scalar - formatted as per `spec`.

    Args:
      f: the writable stream (the pipe to the compiler) to send on.
      value: the advice to send; it is converted with `int()` and encoded as a
        `ctypes.c_int64`.
      spec: the tensor spec describing the advice. Only `ctypes.c_int64`
        element types are supported, and the spec's total size must match the
        single encoded value.

    Raises:
      AssertionError: if `spec` is not an int64 spec, or if the number of bytes
        written does not equal the size described by `spec`.
    """
    # just int64 for now
    assert spec.element_type == ctypes.c_int64
    to_send = ctypes.c_int64(int(value))
    # Do the write outside the `assert`: assertions are stripped under
    # `python -O`, which would otherwise silently skip sending the advice.
    written = f.write(bytes(to_send))
    expected = ctypes.sizeof(spec.element_type) * math.prod(spec.shape)
    assert written == expected, f"wrote {written} bytes, expected {expected}"
    f.flush()


def run_interactive(
    temp_rootname: str,
    make_response: Callable[[List[log_reader.TensorValue]], Union[int, float]],
    process_and_args: List[str],
) -> None:
    """Host the compiler.

    Args:
      temp_rootname: the base file name from which to construct the 2 pipes for
        communicating with the compiler.
      make_response: a function that, given the current tensor values, provides a
        response.
      process_and_args: the full commandline for the compiler. It is assumed it
        contains a flag pointing to `temp_rootname` so that the
        InteractiveModelRunner would attempt communication on the same pair as
        this function opens.

    This function sets up the communication with the compiler - via 2 named
    pipes (FIFOs) named `temp_rootname`.in and `temp_rootname`.out - prints out
    the received features, and sends back to the compiler an advice (which it
    gets from `make_response`). After the exchange ends, the compiler's stderr
    is printed. It's used for testing, and also to showcase how to set up
    communication in an interactive ML ("gym") environment.

    On exit, only the FIFOs created by this call are removed. If an error aborts
    the session while the compiler is still running, the compiler is killed.
    """
    to_compiler = temp_rootname + ".in"
    from_compiler = temp_rootname + ".out"
    # Track what we actually created so cleanup never removes a pre-existing
    # file, and never masks the original error with a FileNotFoundError.
    created_fifos = []
    compiler_proc = None
    try:
        for fifo in (to_compiler, from_compiler):
            os.mkfifo(fifo, 0o666)
            created_fifos.append(fifo)
        # NOTE(review): stderr is only drained by `communicate()` after the
        # session ends; a compiler that writes more than the OS pipe buffer to
        # stderr mid-session would block. Acceptable for tests with modest
        # diagnostics.
        compiler_proc = subprocess.Popen(
            process_and_args, stderr=subprocess.PIPE, stdout=subprocess.DEVNULL
        )
        # Opening a FIFO blocks until its other end is opened. The order here
        # (write end of .in, then read end of .out) presumably mirrors the
        # order the compiler opens them; changing it could deadlock.
        with io.BufferedWriter(io.FileIO(to_compiler, "wb")) as tc:
            with io.BufferedReader(io.FileIO(from_compiler, "rb")) as fc:
                tensor_specs, _, advice_spec = log_reader.read_header(fc)
                context = None
                while compiler_proc.poll() is None:
                    next_event = fc.readline()
                    if not next_event:
                        # EOF: the compiler closed its end of the pipe.
                        break
                    (
                        last_context,
                        observation_id,
                        features,
                        _,
                    ) = log_reader.read_one_observation(
                        context, next_event, fc, tensor_specs, None
                    )
                    if last_context != context:
                        print(f"context: {last_context}")
                        context = last_context
                    print(f"observation: {observation_id}")
                    # Materialize once: the same values are printed and then
                    # handed to `make_response`.
                    tensor_values = list(features)
                    for fv in tensor_values:
                        log_reader.pretty_print_tensor_value(fv)
                    send(tc, make_response(tensor_values), advice_spec)
        # Both pipe ends are closed at this point. communicate() reads the
        # remaining stderr and waits for the process to exit.
        _, err = compiler_proc.communicate()
        print(err.decode("utf-8", errors="replace"))

    finally:
        if compiler_proc is not None and compiler_proc.poll() is None:
            # The session was interrupted by an error; don't leave the
            # compiler running (possibly blocked on the FIFOs).
            compiler_proc.kill()
            compiler_proc.wait()
        for fifo in created_fifos:
            os.unlink(fifo)