xref: /llvm-project/llvm/lib/Analysis/models/log_reader.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
"""Reader for training log.

See lib/Analysis/TrainingLogger.cpp for a description of the format.
"""
54c97745bSMircea Trofinimport ctypes
64c97745bSMircea Trofinimport dataclasses
7954cf9a7SMircea Trofinimport io
84c97745bSMircea Trofinimport json
94c97745bSMircea Trofinimport math
104c97745bSMircea Trofinimport sys
11b72e893dSMircea Trofinfrom typing import List, Optional
124c97745bSMircea Trofin
# Element types that may appear in a tensor spec, keyed by the type name used
# in the log header and mapped to the matching ctypes scalar type.
_element_types = {
    # Floating-point names map directly: "float" -> ctypes.c_float.
    **{name: getattr(ctypes, f"c_{name}") for name in ("float", "double")},
    # Fixed-width integers lose the "_t" suffix: "uint16_t" -> ctypes.c_uint16.
    **{
        f"{sign}int{width}_t": getattr(ctypes, f"c_{sign}int{width}")
        for width in (8, 16, 32, 64)
        for sign in ("", "u")
    },
}
254c97745bSMircea Trofin
264c97745bSMircea Trofin
@dataclasses.dataclass(frozen=True)
class TensorSpec:
    """Description of one tensor stored in the training log.

    Instances are normally created with `from_dict` from the dictionaries
    found in the JSON header of the log.
    """

    # Value of the "name" key of the spec dictionary.
    name: str
    # Value of the "port" key, stored as given.
    port: int
    # Dimensions of the tensor; the element count is their product.
    shape: List[int]
    # ctypes scalar type of one element (one of the values of _element_types).
    element_type: type

    @staticmethod
    def from_dict(d: dict) -> "TensorSpec":
        """Build a TensorSpec from its dictionary (parsed JSON) form.

        Args:
            d: dictionary with the keys "name", "port", "shape" (an iterable of
              values convertible to int) and "type" (one of the keys of
              `_element_types`, e.g. "float" or "int64_t").

        Returns:
            The TensorSpec described by `d`.

        Raises:
            KeyError: if one of the required keys is missing from `d`.
            ValueError: if "type" is not a supported element type name, or a
              shape entry cannot be converted to int.
        """
        name = d["name"]
        port = d["port"]
        shape = [int(e) for e in d["shape"]]
        element_type_str = d["type"]
        try:
            element_type = _element_types[element_type_str]
        except KeyError:
            # Report the offending name rather than a bare KeyError.
            raise ValueError(f"unknown type: {element_type_str}") from None
        return TensorSpec(
            name=name,
            port=port,
            shape=shape,
            element_type=element_type,
        )
494c97745bSMircea Trofin
class TensorValue:
    """Read-only, indexable view over the raw bytes of one tensor.

    Elements are addressed with a flat index in `[0, len(self))`, whatever the
    rank of the spec's shape. Iterating over a TensorValue yields its elements
    in index order.
    """

    def __init__(self, spec: "TensorSpec", buffer: bytes):
        """Interpret `buffer` as a sequence of `spec.element_type` values.

        Args:
            spec: description of the tensor; its `shape` and `element_type`
              determine how `buffer` is decoded.
            buffer: the raw tensor data.

        Raises:
            ValueError: if `buffer` is too short to hold every element. Without
              this check, indexing would read past the end of the buffer
              through the ctypes pointer.
        """
        self._spec = spec
        # Keep the buffer referenced for as long as the pointer below is used.
        self._buffer = buffer
        self._len = math.prod(self._spec.shape)
        required = self._len * ctypes.sizeof(self._spec.element_type)
        if len(self._buffer) < required:
            raise ValueError(
                f"buffer for tensor {self._spec.name!r} has {len(self._buffer)} "
                f"bytes, expected at least {required}"
            )
        self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type))

    def spec(self) -> "TensorSpec":
        """Return the spec this value was created with."""
        return self._spec

    def __len__(self) -> int:
        """Return the number of elements (the product of the shape)."""
        return self._len

    def __getitem__(self, index):
        """Return the element at flat position `index`.

        Raises:
            IndexError: if `index` is negative or not less than `len(self)`.
              Negative indices are deliberately not wrapped around.
        """
        if index < 0 or index >= self._len:
            raise IndexError(f"Index {index} out of range [0..{self._len})")
        return self._view[index]
674c97745bSMircea Trofin
684c97745bSMircea Trofin
def read_tensor(fs: io.BufferedReader, ts: "TensorSpec") -> "TensorValue":
    """Read the binary data of one tensor from `fs`.

    Args:
        fs: buffered binary stream positioned at the first byte of the tensor.
        ts: spec of the tensor; its shape and element type give the byte count.

    Returns:
        A TensorValue wrapping the bytes that were read.

    Raises:
        EOFError: if the stream ends before the whole tensor was read.
    """
    size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
    data = fs.read(size)
    # A buffered reader only returns fewer bytes than requested at end of
    # stream; refuse to wrap a truncated payload.
    if len(data) != size:
        raise EOFError(
            f"expected {size} bytes for tensor {ts.name!r}, got {len(data)}"
        )
    return TensorValue(ts, data)
734c97745bSMircea Trofin
744c97745bSMircea Trofin
def pretty_print_tensor_value(tv: "TensorValue") -> None:
    """Print `tv` on a single line as `<name>: <v0>,<v1>,...`."""
    tensor_name = tv.spec().name
    rendered_values = ",".join(map(str, tv))
    print(f"{tensor_name}: {rendered_values}")
774c97745bSMircea Trofin
78d62cdfadSMircea Trofin
def read_header(f: io.BufferedReader):
    """Parse the header line of a training log.

    The header is the first line read from `f`, a JSON object with a
    "features" list of tensor spec dictionaries and, optionally, a "score"
    and an "advice" tensor spec dictionary.

    Returns:
        A `(feature_specs, score_spec, advice_spec)` tuple; the last two are
        None when the header does not contain the corresponding key.
    """
    header = json.loads(f.readline())

    def optional_spec(key: str):
        # Specs other than the features may be absent from the header.
        if key not in header:
            return None
        return TensorSpec.from_dict(header[key])

    feature_specs = [TensorSpec.from_dict(entry) for entry in header["features"]]
    return feature_specs, optional_spec("score"), optional_spec("advice")
85954cf9a7SMircea Trofin
86954cf9a7SMircea Trofin
def read_one_observation(
    context: Optional[str],
    event_str: str,
    f: io.BufferedReader,
    tensor_specs: List["TensorSpec"],
    score_spec: Optional["TensorSpec"],
):
    """Parse one observation, starting from its already-read event line.

    Args:
        context: the context in effect before this observation.
        event_str: the JSON line introducing the observation (anything
          `json.loads` accepts). If it has a "context" key, the context is
          replaced by its value and the observation line is read from `f`.
        f: stream positioned right after `event_str`.
        tensor_specs: specs of the feature tensors, in the order in which
          their data is read.
        score_spec: spec of the score tensor, or None if no score is read.

    Returns:
        A `(context, observation_id, features, score)` tuple. `features` is a
        list with one tensor value per entry of `tensor_specs`; `score` is
        None when `score_spec` is None.

    Raises:
        ValueError: if the "outcome" id of the score record differs from the
          observation id.
    """
    event = json.loads(event_str)
    if "context" in event:
        context = event["context"]
        event = json.loads(f.readline())
    observation_id = int(event["observation"])
    features = [read_tensor(f, ts) for ts in tensor_specs]
    # Discard the rest of the line that follows the feature data (presumably
    # just the newline terminating the binary payload).
    f.readline()
    score = None
    if score_spec is not None:
        score_header = json.loads(f.readline())
        outcome_id = int(score_header["outcome"])
        # Raise instead of assert: the log is external input and asserts are
        # stripped when Python runs with -O.
        if outcome_id != observation_id:
            raise ValueError(
                f"outcome {outcome_id} does not match observation {observation_id}"
            )
        score = read_tensor(f, score_spec)
        # Same as above, for the line holding the score data.
        f.readline()
    return context, observation_id, features, score
110954cf9a7SMircea Trofin
111954cf9a7SMircea Trofin
def read_stream(fname: str):
    """Yield every observation stored in the training log `fname`.

    The header is parsed once (its advice spec is not used here); each
    following record is then decoded with `read_one_observation`. The context
    returned for one observation is passed on to the next one.

    Yields:
        `(context, observation_id, features, score)` tuples, in file order.
    """
    with open(fname, "rb") as f:
        tensor_specs, score_spec, _ = read_header(f)
        context = None
        # readline() returns an empty bytes object at end of file.
        while event_str := f.readline():
            observation = read_one_observation(
                context, event_str, f, tensor_specs, score_spec
            )
            context = observation[0]
            yield observation
1244c97745bSMircea Trofin
1254c97745bSMircea Trofin
def main(args):
    """Print every observation of a training log in a readable form.

    Args:
        args: command-line arguments in `sys.argv` form; `args[1]` is the path
          of the log file to print.

    Raises:
        SystemExit: with a usage message if no log file was given.
    """
    if len(args) < 2:
        prog = args[0] if args else "log_reader.py"
        raise SystemExit(f"usage: {prog} <log file>")
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        # Only announce the context when it changes.
        if last_context != ctx:
            print(f"context: {ctx}")
            last_context = ctx
        print(f"observation: {obs_id}")
        for fv in features:
            pretty_print_tensor_value(fv)
        # Compare with None explicitly: TensorValue defines __len__, so a plain
        # truth test would skip a score tensor that has no elements.
        if score is not None:
            pretty_print_tensor_value(score)
1374c97745bSMircea Trofin
1384c97745bSMircea Trofin
# When run as a script, print the log file named by the first command-line
# argument.
if __name__ == "__main__":
    main(sys.argv)
141