xref: /llvm-project/llvm/lib/Analysis/models/log_reader.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
1"""Reader for training log.
2
3See lib/Analysis/TrainingLogger.cpp for a description of the format.
4"""
5import ctypes
6import dataclasses
7import io
8import json
9import math
10import sys
11from typing import List, Optional
12
13_element_types = {
14    "float": ctypes.c_float,
15    "double": ctypes.c_double,
16    "int8_t": ctypes.c_int8,
17    "uint8_t": ctypes.c_uint8,
18    "int16_t": ctypes.c_int16,
19    "uint16_t": ctypes.c_uint16,
20    "int32_t": ctypes.c_int32,
21    "uint32_t": ctypes.c_uint32,
22    "int64_t": ctypes.c_int64,
23    "uint64_t": ctypes.c_uint64,
24}
25
26
@dataclasses.dataclass(frozen=True)
class TensorSpec:
    """Describes one tensor in the log: its name, port, shape and element type.

    Instances are immutable (the dataclass is frozen).
    """

    name: str
    port: int
    # Dimensions of the tensor; every entry is converted with int() by from_dict.
    shape: List[int]
    # The ctypes class used to interpret the tensor's raw bytes.
    element_type: type

    @staticmethod
    def from_dict(d: dict) -> "TensorSpec":
        """Build a TensorSpec from a dict with "name", "port", "shape" and "type".

        Args:
            d: mapping holding the four keys above. "type" must be one of the
                names in the module-level _element_types table.

        Returns:
            The corresponding TensorSpec.

        Raises:
            KeyError: if one of the required keys is missing.
            ValueError: if "type" is not a known element type name.
        """
        name = d["name"]
        port = d["port"]
        shape = [int(e) for e in d["shape"]]
        element_type_str = d["type"]
        try:
            element_type = _element_types[element_type_str]
        except KeyError:
            # Report the offending type name instead of a bare KeyError.
            raise ValueError(f"unknown type: {element_type_str}") from None
        return TensorSpec(
            name=name,
            port=port,
            shape=shape,
            element_type=element_type,
        )
48
49
class TensorValue:
    """Indexable, typed view over the raw bytes of one tensor.

    Elements are exposed in flat order: len() is the product of the spec's
    shape and tv[i] is the i-th element interpreted as spec.element_type.
    """

    def __init__(self, spec: "TensorSpec", buffer: bytes):
        """Wrap buffer as a tensor described by spec.

        Args:
            spec: describes the shape and element type of the tensor.
            buffer: the raw bytes of the tensor's elements.

        Raises:
            ValueError: if buffer holds fewer bytes than spec requires.
        """
        self._spec = spec
        # Hold on to the bytes object: the ctypes pointer created below refers
        # to this buffer's memory.
        self._buffer = buffer
        self._len = math.prod(self._spec.shape)
        required = self._len * ctypes.sizeof(self._spec.element_type)
        if len(buffer) < required:
            # A ctypes pointer does no bounds checking, so indexing a short
            # buffer would read memory past its end.
            raise ValueError(
                f"Buffer for tensor {self._spec.name} has {len(buffer)} bytes, "
                f"expected at least {required}"
            )
        self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type))

    def spec(self) -> "TensorSpec":
        """Return the spec this value was created with."""
        return self._spec

    def __len__(self) -> int:
        """Return the number of elements (the product of the spec's shape)."""
        return self._len

    def __getitem__(self, index):
        """Return the element at the given flat index.

        Negative indices are not supported. The IndexError raised past the end
        also terminates `for v in tensor_value` loops, since the class defines
        no __iter__.

        Raises:
            IndexError: if index is outside [0, len(self)).
        """
        if index < 0 or index >= self._len:
            raise IndexError(f"Index {index} out of range [0..{self._len})")
        return self._view[index]
67
68
def read_tensor(fs: io.BufferedReader, ts: "TensorSpec") -> "TensorValue":
    """Read the raw bytes of one tensor described by ts from fs.

    Args:
        fs: binary stream positioned at the first byte of the tensor.
        ts: spec giving the tensor's shape and element type.

    Returns:
        A TensorValue wrapping the bytes that were read.

    Raises:
        EOFError: if the stream ends before the whole tensor was read.
    """
    size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
    data = fs.read(size)
    if len(data) != size:
        # A short read means the log is truncated; do not hand a partial
        # buffer to TensorValue.
        raise EOFError(
            f"Expected {size} bytes for tensor {ts.name}, got {len(data)}"
        )
    return TensorValue(ts, data)
73
74
def pretty_print_tensor_value(tv: TensorValue):
    """Print the tensor on one line as "<name>: <v0>,<v1>,...". """
    tensor_name = tv.spec().name
    rendered_values = ",".join(map(str, tv))
    print(f"{tensor_name}: {rendered_values}")
77
78
def read_header(f: io.BufferedReader):
    """Parse the JSON header found on the next line of f.

    Returns a tuple (feature specs, score spec, advice spec). The header must
    have a "features" list; the score and advice specs are None when the
    header has no "score" / "advice" entry.
    """
    header = json.loads(f.readline())

    def optional_spec(key):
        # Only build a spec for entries the header actually contains.
        return TensorSpec.from_dict(header[key]) if key in header else None

    feature_specs = [TensorSpec.from_dict(entry) for entry in header["features"]]
    return feature_specs, optional_spec("score"), optional_spec("advice")
85
86
def read_one_observation(
    context: Optional[str],
    event_str: str,
    f: io.BufferedReader,
    tensor_specs: List["TensorSpec"],
    score_spec: Optional["TensorSpec"],
):
    """Read one observation (and its score, if any) from the log.

    Args:
        context: the context in effect before this observation.
        event_str: the already-read JSON line that starts the record. It is
            either the observation header or a context marker; in the latter
            case the observation header is read from f.
        f: stream positioned right after event_str.
        tensor_specs: specs of the feature tensors, in the order they are stored.
        score_spec: spec of the score tensor, or None if the log has no scores.

    Returns:
        A tuple (context, observation id, feature tensors, score tensor or
        None), where context is the new context if event_str was a context
        marker and the passed-in one otherwise.

    Raises:
        ValueError: if the score record's "outcome" does not match the
            observation id.
    """
    event = json.loads(event_str)
    if "context" in event:
        context = event["context"]
        event = json.loads(f.readline())
    observation_id = int(event["observation"])
    features = [read_tensor(f, ts) for ts in tensor_specs]
    # Skip the rest of the line after the raw feature bytes (presumably just a
    # newline separator; see TrainingLogger.cpp for the format).
    f.readline()
    score = None
    if score_spec is not None:
        score_header = json.loads(f.readline())
        outcome_id = int(score_header["outcome"])
        # Raise explicitly: an assert would be stripped when running with -O.
        if outcome_id != observation_id:
            raise ValueError(
                f"outcome {outcome_id} does not match observation {observation_id}"
            )
        score = read_tensor(f, score_spec)
        # Same separator skip as after the features.
        f.readline()
    return context, observation_id, features, score
110
111
def read_stream(fname: str):
    """Yield one (context, observation id, features, score) tuple per observation.

    The file is opened in binary mode, its header is parsed once, and
    observations are then read until the end of the file. The context is
    carried over from one observation to the next.
    """
    with open(fname, "rb") as log:
        feature_specs, score_spec, _ = read_header(log)
        current_context = None
        # readline() returns b"" only at end of file, which ends the loop.
        for first_line in iter(log.readline, b""):
            record = read_one_observation(
                current_context, first_line, log, feature_specs, score_spec
            )
            current_context = record[0]
            yield record
124
125
def main(args):
    """Pretty-print every observation of the log file named by args[1].

    Args:
        args: command line in sys.argv form: program name, then the path of
            the log file. Further entries are ignored.

    Raises:
        SystemExit: with a usage message if no log file was given.
    """
    if len(args) < 2:
        prog = args[0] if args else "log_reader.py"
        raise SystemExit(f"usage: {prog} <training log file>")
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        # Announce the context only when it changes.
        if last_context != ctx:
            print(f"context: {ctx}")
            last_context = ctx
        print(f"observation: {obs_id}")
        for fv in features:
            pretty_print_tensor_value(fv)
        # Compare against None: truthiness of a TensorValue goes through
        # __len__, so a zero-length score would otherwise be skipped.
        if score is not None:
            pretty_print_tensor_value(score)
137
138
# Script entry point: main() reads the log file path from sys.argv[1].
if __name__ == "__main__":
    main(sys.argv)
141