xref: /openbsd-src/gnu/llvm/llvm/lib/Analysis/models/log_reader.py (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
1"""Reader for training log.
2
3See lib/Analysis/TrainingLogger.cpp for a description of the format.
4"""
5import ctypes
6import dataclasses
7import json
8import math
9import sys
10import typing
11
# Maps the element type names accepted in a tensor spec's 'type' field to the
# ctypes type used to decode values of that type.
_element_types = {
    'float': ctypes.c_float,
    'double': ctypes.c_double,
    # Fixed-width integers: '[u]int<bits>_t' -> ctypes.c_[u]int<bits>, signed
    # before unsigned for each width.
    **{
        f'{sign}int{bits}_t': getattr(ctypes, f'c_{sign}int{bits}')
        for bits in (8, 16, 32, 64)
        for sign in ('', 'u')
    },
}
24
25
@dataclasses.dataclass(frozen=True)
class TensorSpec:
  """Immutable description of one tensor: name, port, shape and element type."""
  # Tensor name, taken verbatim from the 'name' entry of the spec dictionary.
  name: str
  # Taken verbatim from the 'port' entry of the spec dictionary.
  port: int
  # Dimensions of the tensor; the element count is the product of the entries.
  shape: list[int]
  # ctypes type used to decode a single element (a value of _element_types).
  element_type: type

  @staticmethod
  def from_dict(d: dict):
    """Builds a TensorSpec from its dictionary form.

    Args:
      d: mapping with the keys 'name', 'port', 'shape' (an iterable of values
        convertible to int) and 'type' (one of the keys of _element_types).

    Returns:
      The corresponding TensorSpec.

    Raises:
      KeyError: if one of the required keys is missing.
      ValueError: if 'type' does not name a supported element type.
    """
    name = d['name']
    port = d['port']
    shape = [int(e) for e in d['shape']]
    element_type_str = d['type']
    if element_type_str not in _element_types:
      raise ValueError(f'unknown type: {element_type_str}')
    return TensorSpec(
        name=name,
        port=port,
        shape=shape,
        element_type=_element_types[element_type_str])
46
47
class TensorValue:
  """Flat, read-only, index-based view over the raw bytes of one tensor.

  Elements are decoded with the spec's ctypes element type. The length is the
  product of the spec's shape, and valid indices are 0 <= index < len(self).
  """

  def __init__(self, spec: 'TensorSpec', buffer: bytes):
    """Wraps buffer as a tensor described by spec.

    Args:
      spec: provides the shape and the element type of the tensor.
      buffer: raw tensor bytes; must hold at least
        prod(spec.shape) * sizeof(spec.element_type) bytes.

    Raises:
      ValueError: if buffer is too small for the tensor described by spec.
    """
    self._spec = spec
    self._buffer = buffer
    self._len = math.prod(self._spec.shape)
    # The typed pointer created below performs no bounds checking, so a short
    # buffer would make __getitem__ read memory past the end of the bytes
    # object. Reject it up front instead.
    needed = self._len * ctypes.sizeof(self._spec.element_type)
    if len(buffer) < needed:
      raise ValueError(
          f'buffer for tensor {spec.name!r} has {len(buffer)} bytes, '
          f'expected at least {needed}')
    self._view = ctypes.cast(self._buffer,
                             ctypes.POINTER(self._spec.element_type))

  def spec(self) -> 'TensorSpec':
    """Returns the spec this value was created with."""
    return self._spec

  def __len__(self) -> int:
    """Returns the number of elements (the product of the spec's shape)."""
    return self._len

  def __getitem__(self, index):
    """Returns the element at the given flat index.

    Raises:
      IndexError: if index is negative or not smaller than len(self). This is
        also what ends a plain `for v in tensor_value` iteration.
    """
    if index < 0 or index >= self._len:
      raise IndexError(f'Index {index} out of range [0..{self._len})')
    return self._view[index]
67
68
def read_tensor(fs: typing.BinaryIO, ts: TensorSpec) -> TensorValue:
  """Reads the raw bytes of one tensor from fs and wraps them.

  Args:
    fs: binary stream positioned at the first byte of the tensor data.
    ts: spec giving the shape and element type, and therefore the byte count.

  Returns:
    A TensorValue over the bytes that were read.

  Raises:
    EOFError: if the stream ends before the whole tensor could be read.
  """
  size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
  data = fs.read(size)
  # read() may return fewer bytes than requested when the stream ends; handing
  # a short buffer to TensorValue would let it index past the data.
  if len(data) < size:
    raise EOFError(
        f'truncated log: tensor {ts.name!r} needs {size} bytes, '
        f'got {len(data)}')
  return TensorValue(ts, data)
73
74
def pretty_print_tensor_value(tv: TensorValue):
  """Prints one line of the form '<tensor name>: <v0>,<v1>,...' for tv."""
  label = tv.spec().name
  rendered = ','.join(map(str, tv))
  print(f'{label}: {rendered}')
77
78
def read_stream(fname: str):
  """Generates the observations stored in the log file fname, in file order.

  Layout, as consumed here:
    * The first line is a JSON header. Its 'features' entry is a list of
      tensor spec dictionaries; an optional 'score' entry is one more spec.
    * Every following record starts with a JSON line. A line with a 'context'
      key updates the current context. Any other line must carry an
      'observation' id and is followed by the raw bytes of every feature
      tensor, in header order, and then the rest of that line.
    * If the header has a 'score' spec, each observation is followed by a JSON
      line whose 'outcome' must equal the observation id, the raw bytes of the
      score tensor, and then the rest of that line.

  Args:
    fname: path of the log file.

  Yields:
    (context, observation_id, features, score) tuples. context is the value of
    the most recent 'context' event (None before the first one), features is a
    list of TensorValue and score is a TensorValue, or None when the header
    declares no score.

  Raises:
    ValueError: if a score record's 'outcome' does not match the id of the
      observation it follows.
  """
  with open(fname, 'rb') as f:
    header = json.loads(f.readline())
    tensor_specs = [TensorSpec.from_dict(ts) for ts in header['features']]
    score_spec = TensorSpec.from_dict(
        header['score']) if 'score' in header else None
    context = None
    while event_str := f.readline():
      event = json.loads(event_str)
      if 'context' in event:
        context = event['context']
        continue
      observation_id = int(event['observation'])
      features = [read_tensor(f, ts) for ts in tensor_specs]
      # Consume the remainder of the line that follows the raw tensor bytes
      # (presumably just the terminating newline) so that the next readline()
      # starts at the next JSON record.
      f.readline()
      score = None
      if score_spec is not None:
        score_header = json.loads(f.readline())
        outcome_id = int(score_header['outcome'])
        # Raise explicitly rather than assert: asserts are removed under -O,
        # which would let a mismatched log be read silently.
        if outcome_id != observation_id:
          raise ValueError(
              f'outcome {outcome_id} does not match observation '
              f'{observation_id}')
        score = read_tensor(f, score_spec)
        f.readline()
      yield context, observation_id, features, score
103
104
def main(args):
  """Prints every observation of the log file named by args[1].

  A 'context: ...' line is printed whenever the context differs from the one of
  the previous observation. Each observation then prints its id, one line per
  feature tensor and, when a score is present, one line for the score.

  Args:
    args: argument vector; args[1] is the path of the log file.
  """
  last_context = None
  for ctx, obs_id, features, score in read_stream(args[1]):
    if last_context != ctx:
      print(f'context: {ctx}')
      last_context = ctx
    print(f'observation: {obs_id}')
    for fv in features:
      pretty_print_tensor_value(fv)
    # Compare against None explicitly: TensorValue defines __len__, so a plain
    # truthiness test would also skip a score tensor that has zero elements.
    if score is not None:
      pretty_print_tensor_value(score)
116
117
# Script entry point: pass the full argument vector (program name at index 0,
# log file path at index 1) to main.
if __name__ == '__main__':
  main(sys.argv)
120