"""Reader for training log. See lib/Analysis/TrainingLogger.cpp for a description of the format. """ import ctypes import dataclasses import io import json import math import sys from typing import List, Optional _element_types = { "float": ctypes.c_float, "double": ctypes.c_double, "int8_t": ctypes.c_int8, "uint8_t": ctypes.c_uint8, "int16_t": ctypes.c_int16, "uint16_t": ctypes.c_uint16, "int32_t": ctypes.c_int32, "uint32_t": ctypes.c_uint32, "int64_t": ctypes.c_int64, "uint64_t": ctypes.c_uint64, } @dataclasses.dataclass(frozen=True) class TensorSpec: name: str port: int shape: List[int] element_type: type @staticmethod def from_dict(d: dict): name = d["name"] port = d["port"] shape = [int(e) for e in d["shape"]] element_type_str = d["type"] if element_type_str not in _element_types: raise ValueError(f"uknown type: {element_type_str}") return TensorSpec( name=name, port=port, shape=shape, element_type=_element_types[element_type_str], ) class TensorValue: def __init__(self, spec: TensorSpec, buffer: bytes): self._spec = spec self._buffer = buffer self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type)) self._len = math.prod(self._spec.shape) def spec(self) -> TensorSpec: return self._spec def __len__(self) -> int: return self._len def __getitem__(self, index): if index < 0 or index >= self._len: raise IndexError(f"Index {index} out of range [0..{self._len})") return self._view[index] def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue: size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type) data = fs.read(size) return TensorValue(ts, data) def pretty_print_tensor_value(tv: TensorValue): print(f'{tv.spec().name}: {",".join([str(v) for v in tv])}') def read_header(f: io.BufferedReader): header = json.loads(f.readline()) tensor_specs = [TensorSpec.from_dict(ts) for ts in header["features"]] score_spec = TensorSpec.from_dict(header["score"]) if "score" in header else None advice_spec = TensorSpec.from_dict(header["advice"]) if "advice" in header else None return tensor_specs, score_spec, advice_spec def read_one_observation( context: Optional[str], event_str: str, f: io.BufferedReader, tensor_specs: List[TensorSpec], score_spec: Optional[TensorSpec], ): event = json.loads(event_str) if "context" in event: context = event["context"] event = json.loads(f.readline()) observation_id = int(event["observation"]) features = [] for ts in tensor_specs: features.append(read_tensor(f, ts)) f.readline() score = None if score_spec is not None: score_header = json.loads(f.readline()) assert int(score_header["outcome"]) == observation_id score = read_tensor(f, score_spec) f.readline() return context, observation_id, features, score def read_stream(fname: str): with io.BufferedReader(io.FileIO(fname, "rb")) as f: tensor_specs, score_spec, _ = read_header(f) context = None while True: event_str = f.readline() if not event_str: break context, observation_id, features, score = read_one_observation( context, event_str, f, tensor_specs, score_spec ) yield context, observation_id, features, score def main(args): last_context = None for ctx, obs_id, features, score in read_stream(args[1]): if last_context != ctx: print(f"context: {ctx}") last_context = ctx print(f"observation: {obs_id}") for fv in features: pretty_print_tensor_value(fv) if score: pretty_print_tensor_value(score) if __name__ == "__main__": main(sys.argv)