"""Reader for training log.

See lib/Analysis/TrainingLogger.cpp for a description of the format.

As consumed by this module, a log is:
  * one JSON header line with a "features" list of tensor specs and,
    optionally, a "score" and an "advice" tensor spec;
  * a sequence of observations, each made of:
      - optionally, a JSON line carrying a "context" key, which switches the
        current context;
      - a JSON line carrying an "observation" key (the observation id);
      - the raw bytes of every feature tensor, back to back, followed by the
        remainder of that line (a line terminator);
      - only if the header declared a score: a JSON line carrying an
        "outcome" key equal to the observation id, the raw bytes of the
        score tensor, and the remainder of that line.
"""
import ctypes
import dataclasses
import io
import json
import math
import sys
from typing import List, Optional

# Maps the element type names used in the log header to the ctypes type used
# to size and decode the raw tensor bytes.
_element_types = {
    "float": ctypes.c_float,
    "double": ctypes.c_double,
    "int8_t": ctypes.c_int8,
    "uint8_t": ctypes.c_uint8,
    "int16_t": ctypes.c_int16,
    "uint16_t": ctypes.c_uint16,
    "int32_t": ctypes.c_int32,
    "uint32_t": ctypes.c_uint32,
    "int64_t": ctypes.c_int64,
    "uint64_t": ctypes.c_uint64,
}


@dataclasses.dataclass(frozen=True)
class TensorSpec:
    """Description of one tensor in the log: name, port, shape, element type."""

    name: str
    port: int
    shape: List[int]
    element_type: type

    @staticmethod
    def from_dict(d: dict):
        """Build a TensorSpec from a decoded JSON spec.

        Args:
            d: mapping with the keys "name", "port", "shape" and "type".

        Returns:
            The corresponding TensorSpec, with "type" resolved to a ctypes
            type and every "shape" entry converted to int.

        Raises:
            KeyError: if one of the keys is missing.
            ValueError: if "type" is not one of the supported element types.
        """
        name = d["name"]
        port = d["port"]
        shape = [int(e) for e in d["shape"]]
        element_type_str = d["type"]
        if element_type_str not in _element_types:
            raise ValueError(f"unknown type: {element_type_str}")
        return TensorSpec(
            name=name,
            port=port,
            shape=shape,
            element_type=_element_types[element_type_str],
        )


class TensorValue:
    """Read-only, flat, indexable view over the raw bytes of one tensor."""

    def __init__(self, spec: TensorSpec, buffer: bytes):
        """Wrap `buffer` as a flat sequence of `spec.element_type` values.

        `buffer` is expected to hold at least
        prod(spec.shape) * sizeof(spec.element_type) bytes; this constructor
        does not check that.
        """
        self._spec = spec
        # The view below is a pointer into this bytes object, so the object
        # is kept referenced for as long as the TensorValue exists.
        self._buffer = buffer
        self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type))
        self._len = math.prod(self._spec.shape)

    def spec(self) -> TensorSpec:
        """Return the spec this value was created with."""
        return self._spec

    def __len__(self) -> int:
        """Return the number of elements, i.e. the product of the shape."""
        return self._len

    def __getitem__(self, index):
        """Return the element at flat position `index`.

        Raises:
            IndexError: if `index` is outside [0, len(self)). Negative
                indices are rejected rather than wrapped. Raising IndexError
                is also what terminates plain `for v in tensor_value` loops.
        """
        if index < 0 or index >= self._len:
            raise IndexError(f"Index {index} out of range [0..{self._len})")
        return self._view[index]


def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue:
    """Read the raw bytes of one tensor described by `ts` from `fs`.

    Raises:
        ValueError: if the stream ends before the whole tensor was read.
    """
    size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
    data = fs.read(size)
    # A short buffer must not reach TensorValue: indexing its pointer view
    # would read past the end of the bytes object.
    if len(data) != size:
        raise ValueError(
            f"truncated tensor {ts.name}: expected {size} bytes, got {len(data)}"
        )
    return TensorValue(ts, data)


def pretty_print_tensor_value(tv: TensorValue):
    """Print `tv` to stdout as `<name>: <v0>,<v1>,...`."""
    print(f'{tv.spec().name}: {",".join(str(v) for v in tv)}')


def read_header(f: io.BufferedReader):
    """Read and decode the JSON header line of the log.

    Returns:
        A tuple (tensor_specs, score_spec, advice_spec). `tensor_specs` is
        the list of feature specs; `score_spec` and `advice_spec` are None
        when the header has no "score" / "advice" entry.
    """
    header = json.loads(f.readline())
    tensor_specs = [TensorSpec.from_dict(ts) for ts in header["features"]]
    score_spec = TensorSpec.from_dict(header["score"]) if "score" in header else None
    advice_spec = TensorSpec.from_dict(header["advice"]) if "advice" in header else None
    return tensor_specs, score_spec, advice_spec


def read_one_observation(
    context: Optional[str],
    event_str: str,
    f: io.BufferedReader,
    tensor_specs: List[TensorSpec],
    score_spec: Optional[TensorSpec],
):
    """Read one observation, whose first line has already been consumed.

    Args:
        context: the context in effect before this observation.
        event_str: the already-read JSON line that starts the observation;
            either a context switch or the observation line itself.
        f: the stream, positioned right after `event_str`.
        tensor_specs: specs of the feature tensors, in log order.
        score_spec: spec of the score tensor, or None if scores are not
            logged.

    Returns:
        A tuple (context, observation_id, features, score). `context` is
        updated if the observation started with a context switch; `score` is
        None when `score_spec` is None.

    Raises:
        ValueError: if the score's "outcome" id does not match the
            observation id, or if a tensor is truncated.
    """
    event = json.loads(event_str)
    if "context" in event:
        # A context switch is followed by the actual observation line.
        context = event["context"]
        event = json.loads(f.readline())
    observation_id = int(event["observation"])
    features = [read_tensor(f, ts) for ts in tensor_specs]
    # Consume the line terminator that follows the feature data.
    f.readline()
    score = None
    if score_spec is not None:
        score_header = json.loads(f.readline())
        outcome_id = int(score_header["outcome"])
        # Explicit check rather than `assert`, which is stripped under -O.
        if outcome_id != observation_id:
            raise ValueError(
                f"outcome {outcome_id} does not match observation {observation_id}"
            )
        score = read_tensor(f, score_spec)
        # Consume the line terminator that follows the score data.
        f.readline()
    return context, observation_id, features, score


def read_stream(fname: str):
    """Yield (context, observation_id, features, score) for each observation.

    Opens `fname` in binary mode, reads the header, then yields observations
    until end of file. The file stays open until the generator is exhausted
    or closed.
    """
    with open(fname, "rb") as f:
        tensor_specs, score_spec, _ = read_header(f)
        context = None
        # readline() returns b"" only at end of file.
        while event_str := f.readline():
            context, observation_id, features, score = read_one_observation(
                context, event_str, f, tensor_specs, score_spec
            )
            yield context, observation_id, features, score


def main(args):
    """Pretty-print the log named by `args[1]` to stdout."""
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        # Only announce the context when it changes.
        if last_context != ctx:
            print(f"context: {ctx}")
            last_context = ctx
        print(f"observation: {obs_id}")
        for fv in features:
            pretty_print_tensor_value(fv)
        # Compare against None explicitly: TensorValue defines __len__, so a
        # truthiness test would hide a present but zero-element score.
        if score is not None:
            pretty_print_tensor_value(score)


if __name__ == "__main__":
    main(sys.argv)