xref: /llvm-project/libcxx/utils/generate_iwyu_mapping.py (revision e78f53d1e8d622ee4b12dbc2ac8252b4805d5719)
1#!/usr/bin/env python
2
3import argparse
4import libcxx.header_information
5import os
6import pathlib
7import re
8import sys
9import typing
10
def IWYU_mapping(header: str) -> typing.Optional[typing.List[str]]:
    """Return the public headers a libc++ header should be mapped to for IWYU.

    ``header`` is the header's include name, e.g. ``"__algorithm/find.h"``.

    Returns a new list of public header names, or ``None`` when the header is
    ignored or has no mapping (callers skip such headers).

    Rules are applied in this order; the first one that applies wins:
      1. headers matching an ignore pattern -> ``None``
      2. headers with an explicit exact-name mapping
      3. the ordered regex rules below
      4. anything else -> ``None``
    """
    # Patterns are used with re.match, i.e. anchored at the start of the name.
    # Literal dots are written as [.] so they do not match arbitrary characters.
    ignore = (
        "__cxx03/.+",
        "__debug_utils/.+",
        "__fwd/get[.]h",
        "__pstl/.+",
        "__support/.+",
        "__utility/private_constructor_tag[.]h",
    )
    if any(re.match(pattern, header) for pattern in ignore):
        return None

    # Hand-written mappings for specific headers. These take precedence over the
    # generic __fwd/ and __<dir>/ rules below (e.g. __fwd/byte.h -> <cstddef>,
    # not <byte>).
    exact = {
        # NOTE(review): every target must be a public header or main() raises;
        # confirm "bits" is still accepted there.
        "__bits": ("bits",),
        "__bit_reference": ("bitset", "vector"),
        "__fwd/bit_reference.h": ("bitset", "vector"),
        "__config": ("version",),
        "__hash_table": ("unordered_map", "unordered_set"),
        "__locale": ("locale",),
        "__node_handle": ("map", "set", "unordered_map", "unordered_set"),
        "__split_buffer": ("deque", "vector"),
        "__tree": ("map", "set"),
        "__fwd/byte.h": ("cstddef",),
        "__fwd/pair.h": ("utility",),
        "__fwd/subrange.h": ("ranges",),
    }
    if header in exact:
        # Build a fresh list so callers can never mutate the table.
        return list(exact[header])

    # Ordered (pattern, targets) rules. A targets value of None means
    # "map to the header named by the pattern's first capture group".
    rules = (
        ("__configuration/.+", ("version",)),
        ("__locale_dir/.+", ("locale",)),
        ("__math/.+", ("cmath",)),
        (
            "(__thread/support[.]h)|(__thread/support/.+)",
            ("atomic", "mutex", "semaphore", "thread"),
        ),
        ("__fwd/(fstream|ios|istream|ostream|sstream|streambuf)[.]h", ("iosfwd",)),
        # Remaining forward-declaration headers, e.g. <__fwd/span.h> -> <span>
        ("__fwd/(.+)[.]h", None),
        # Detail headers, e.g. <__algorithm/find.h> -> <algorithm>
        ("__(.+?)/.+", None),
    )
    for pattern, targets in rules:
        found = re.match(pattern, header)
        if found is None:
            continue
        if targets is None:
            return [found.group(1)]
        return list(targets)
    return None
60
61
def main(argv: typing.List[str]):
    """Generate the libc++ IWYU mapping file.

    Maps every header in ``libcxx.header_information.all_headers`` through
    ``IWYU_mapping``. Checks that each target is a known public header, then
    writes one IWYU ``include`` entry per (private, public) pair, sorted.

    Args:
        argv: command-line arguments without the program name. ``-o PATH`` is
            required; ``-o -`` writes to stdout.

    Raises:
        RuntimeError: if a mapping targets a header that is not public. The
            output file is not opened (and therefore not truncated) in that case.
    """
    parser = argparse.ArgumentParser()
    # Take a plain path instead of argparse.FileType("w"): FileType opens and
    # truncates the file during parsing, which would leave an empty/partial
    # mapping file behind if validation below fails.
    parser.add_argument(
        "-o",
        help="File to output the IWYU mappings into",
        required=True,
        dest="output",
    )
    args = parser.parse_args(argv)

    mappings = []  # Pairs of (header, public_header)
    for header in libcxx.header_information.all_headers:
        public_headers = IWYU_mapping(str(header))
        if public_headers is not None:
            mappings.extend((header, public) for public in public_headers)

    # Validate that we only have valid public header names -- otherwise the mapping
    # in IWYU_mapping needs to be updated. Done before touching the output file.
    for header, public in mappings:
        if public not in libcxx.header_information.public_headers:
            raise RuntimeError(f"{header}: Header {public} is not a valid header")

    lines = ["[\n"]
    lines.extend(
        f'  {{ include: [ "<{header}>", "private", "<{public}>", "public" ] }},\n'
        for header, public in sorted(mappings)
    )
    lines.append("]\n")
    content = "".join(lines)

    # Preserve argparse.FileType's convention that "-" means stdout.
    if args.output == "-":
        sys.stdout.write(content)
        return

    try:
        output = open(args.output, "w")
    except OSError as err:
        # Mirror FileType's usage error instead of surfacing a raw traceback.
        parser.error(f"argument -o: can't open '{args.output}': {err}")
    with output:
        output.write(content)
91
if __name__ == "__main__":
    # Forward everything after the program name to the generator.
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
94