xref: /llvm-project/llvm/utils/mlgo-utils/mlgo/corpus/combine_training_corpus_lib.py (revision 39f1ca522b023e77c4dca6332d8b9c6366d0eab9)
1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4"""Library for combining training corpora."""
5
6import os
7import json
8import glob
9import logging
10
_FILE_NAME = "corpus_description.json"


def combine_corpus(root_dir: str) -> None:
    """Merge the corpus descriptions found one directory below ``root_dir``.

    Every ``<root_dir>/<sub_dir>/corpus_description.json`` is loaded. Its
    ``"modules"`` entries are prefixed with ``<sub_dir>`` and concatenated
    into a single list. All other keys must be identical across inputs.
    The combined description is written to
    ``<root_dir>/corpus_description.json``. If no input is found, the output
    is ``{"modules": []}``.

    Args:
      root_dir: directory whose immediate subdirectories each hold a corpus
        to combine.

    Raises:
      ValueError: if two input descriptions differ in any key other than
        ``"modules"``.
      KeyError: if an input description has no ``"modules"`` key.
    """
    module_names = []
    # Use None (not {}) to mean "no corpus seen yet". A corpus whose only key
    # is "modules" leaves an empty dict behind after the pop below, and that
    # dict must still be compared against every later input.
    output_corpus_description = None

    # Escape root_dir so glob metacharacters in the path ("[", "*", "?") are
    # matched literally. Sort the matches so the combined module list does not
    # depend on filesystem enumeration order.
    corpus_description_glob = os.path.join(glob.escape(root_dir), "*/" + _FILE_NAME)
    for corpus_description_path in sorted(glob.glob(corpus_description_glob)):
        logging.info("processing %s", corpus_description_path)

        with open(corpus_description_path, encoding="utf-8") as f:
            corpus_description = json.load(f)

        sub_dir = os.path.basename(os.path.dirname(corpus_description_path))
        # Remove "modules" so the remaining keys can be compared across inputs.
        module_names.extend(
            os.path.join(sub_dir, name) for name in corpus_description.pop("modules")
        )
        if output_corpus_description is None:
            output_corpus_description = corpus_description
        elif corpus_description != output_corpus_description:
            raise ValueError("Input corpora differ by more than modules.")

    if output_corpus_description is None:
        output_corpus_description = {}
    output_corpus_description["modules"] = module_names

    with open(os.path.join(root_dir, _FILE_NAME), "w", encoding="utf-8") as f:
        json.dump(output_corpus_description, f, indent=2)

38