"""Reader for training log.

See lib/Analysis/TrainingLogger.cpp for a description of the format.
"""
import ctypes
import dataclasses
import io
import json
import math
import sys
from typing import List, Optional

# Maps the element-type names used in the log's JSON header to ctypes types.
_element_types = {
    "float": ctypes.c_float,
    "double": ctypes.c_double,
    "int8_t": ctypes.c_int8,
    "uint8_t": ctypes.c_uint8,
    "int16_t": ctypes.c_int16,
    "uint16_t": ctypes.c_uint16,
    "int32_t": ctypes.c_int32,
    "uint32_t": ctypes.c_uint32,
    "int64_t": ctypes.c_int64,
    "uint64_t": ctypes.c_uint64,
}


@dataclasses.dataclass(frozen=True)
class TensorSpec:
    """Name, port, shape and element type of one tensor described in the header."""

    name: str
    port: int
    shape: List[int]
    element_type: type  # one of the ctypes types in _element_types

    @staticmethod
    def from_dict(d: dict):
        """Build a TensorSpec from its JSON header dictionary.

        Args:
            d: mapping with "name", "port", "shape" and "type" keys.

        Raises:
            ValueError: if d["type"] is not a key of _element_types.
        """
        name = d["name"]
        port = d["port"]
        shape = [int(e) for e in d["shape"]]
        element_type_str = d["type"]
        if element_type_str not in _element_types:
            raise ValueError(f"unknown type: {element_type_str}")
        return TensorSpec(
            name=name,
            port=port,
            shape=shape,
            element_type=_element_types[element_type_str],
        )


class TensorValue:
    """Flat, read-only view of a tensor's raw bytes, typed by its spec.

    Elements are addressed by a single flat index in [0, len(self)).
    """

    def __init__(self, spec: TensorSpec, buffer: bytes):
        """Wrap `buffer` as a flat array of `spec.element_type`.

        Raises:
            ValueError: if `buffer` is shorter than the tensor described by
                `spec` requires.
        """
        self._spec = spec
        self._buffer = buffer
        self._len = math.prod(self._spec.shape)
        needed = self._len * ctypes.sizeof(self._spec.element_type)
        # The ctypes pointer below does no bounds checking of its own, so a
        # short buffer would let __getitem__ read past the end of it.
        if len(self._buffer) < needed:
            raise ValueError(
                f"buffer for tensor {spec.name} has {len(self._buffer)} bytes, "
                f"expected at least {needed}"
            )
        # self._buffer keeps the underlying bytes alive for as long as the view.
        self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type))

    def spec(self) -> TensorSpec:
        """Return the spec this value was read with."""
        return self._spec

    def __len__(self) -> int:
        """Return the number of elements (product of the shape)."""
        return self._len

    def __getitem__(self, index):
        """Return element `index` of the flattened tensor.

        Negative indices are not supported. IndexError at the end also lets
        `for v in tensor_value` iterate via the sequence protocol.
        """
        if index < 0 or index >= self._len:
            raise IndexError(f"Index {index} out of range [0..{self._len})")
        return self._view[index]


def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue:
    """Read the raw bytes of one tensor described by `ts` from `fs`.

    Raises:
        ValueError: if the stream ends before the full tensor could be read.
    """
    size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
    data = fs.read(size)
    if len(data) != size:
        raise ValueError(
            f"truncated tensor {ts.name}: expected {size} bytes, got {len(data)}"
        )
    return TensorValue(ts, data)


def pretty_print_tensor_value(tv: TensorValue):
    """Print `name: v0,v1,...` for the tensor value to stdout."""
    print(f'{tv.spec().name}: {",".join([str(v) for v in tv])}')


def read_header(f: io.BufferedReader):
    """Parse the first (JSON) line of the log.

    Returns:
        (tensor_specs, score_spec, advice_spec). tensor_specs comes from the
        "features" list. score_spec and advice_spec are None when the header
        has no "score" / "advice" entry.
    """
    header = json.loads(f.readline())
    tensor_specs = [TensorSpec.from_dict(ts) for ts in header["features"]]
    score_spec = TensorSpec.from_dict(header["score"]) if "score" in header else None
    advice_spec = TensorSpec.from_dict(header["advice"]) if "advice" in header else None
    return tensor_specs, score_spec, advice_spec


def read_one_observation(
    context: Optional[str],
    event_str: str,
    f: io.BufferedReader,
    tensor_specs: List[TensorSpec],
    score_spec: Optional[TensorSpec],
):
    """Read one observation record, starting from an already-read event line.

    Args:
        context: the context in effect before this record.
        event_str: the JSON event line already consumed by the caller. If it
            carries a "context" key, the context is updated and the next line
            is read as the observation event.
        f: the log stream, positioned just after `event_str`.
        tensor_specs: the feature tensors to read, in order.
        score_spec: the score tensor spec, or None if the log has no score.

    Returns:
        (context, observation_id, features, score); score is None when
        `score_spec` is None.

    Raises:
        ValueError: if the outcome id does not match the observation id, or a
            tensor is truncated.
    """
    event = json.loads(event_str)
    if "context" in event:
        context = event["context"]
        event = json.loads(f.readline())
    observation_id = int(event["observation"])
    features = [read_tensor(f, ts) for ts in tensor_specs]
    # Consume the rest of the line after the raw tensor bytes (presumably just
    # the line terminator -- see TrainingLogger.cpp).
    f.readline()
    score = None
    if score_spec is not None:
        score_header = json.loads(f.readline())
        outcome_id = int(score_header["outcome"])
        # Explicit check rather than `assert`, which is stripped under -O.
        if outcome_id != observation_id:
            raise ValueError(
                f"outcome id {outcome_id} does not match observation id "
                f"{observation_id}"
            )
        score = read_tensor(f, score_spec)
        f.readline()
    return context, observation_id, features, score


def read_stream(fname: str):
    """Yield (context, observation_id, features, score) for each record in `fname`.

    The advice spec from the header, if any, is parsed but not used here.
    """
    with io.BufferedReader(io.FileIO(fname, "rb")) as f:
        tensor_specs, score_spec, _ = read_header(f)
        context = None
        while True:
            event_str = f.readline()
            if not event_str:
                break
            context, observation_id, features, score = read_one_observation(
                context, event_str, f, tensor_specs, score_spec
            )
            yield context, observation_id, features, score


def main(args):
    """Pretty-print every observation in the log file named by args[1]."""
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        if last_context != ctx:
            print(f"context: {ctx}")
            last_context = ctx
        print(f"observation: {obs_id}")
        for fv in features:
            pretty_print_tensor_value(fv)
        # TensorValue defines __len__, so test identity rather than truthiness.
        if score is not None:
            pretty_print_tensor_value(score)


if __name__ == "__main__":
    main(sys.argv)