xref: /llvm-project/llvm/utils/sort_includes.py (revision b71edfaa4ec3c998aadb35255ce2f60bba2940b0)
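#!/usr/bin/env python

"""Script to sort the top-most block of #include lines.

Assumes the LLVM coding conventions. Includes are regrouped into the main
module header (first quoted include of a .c/.cpp file), local headers,
subproject headers (clang/, clang-c/, polly/), LLVM headers (llvm/, llvm-c/)
and system headers (<...>, plus gtest/, isl/ and json/). Every group after
the main module header is sorted and de-duplicated.

Only the top-most block of #include lines is reordered. Patches welcome for
more functionality.
"""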
10
11import argparse
12import os
13
14
def _include_operand(line):
    """Return the stripped operand of an ``#include`` line, or None.

    Directives that merely share the prefix, such as ``#include_next``, are
    not treated as includes. An ``#include`` with no operand is also rejected.
    """
    if not line.startswith("#include"):
        return None
    rest = line[len("#include") :]
    # "#include <x>" and "#include<x>" are both valid; "#include_next" is not ours.
    if not (rest[:1].isspace() or rest.startswith(("<", '"'))):
        return None
    return rest.strip() or None


def sort_includes(f):
    """Sort the top-most block of #include lines of a specific file in place.

    ``f`` is a text file object opened for reading and writing (e.g. mode
    "r+") whose ``name`` attribute is its path. Files under INPUTS/ or test/
    trees, and files whose extension is not .cpp, .c, .h, .inc or .def, are
    left untouched.

    Before the first include, only blank lines, ``//`` comments, ``#define``
    and ``#ifndef`` lines are allowed; the include block ends at the first
    non-blank line that is not an ``#include``. Within the block, includes are
    regrouped as: main module header (first quoted include of a .c/.cpp file),
    local headers, subproject headers, LLVM headers, system headers. Every
    group after the main module header is sorted and de-duplicated, and blank
    lines inside the block are dropped. The file is only rewritten when its
    contents actually change.
    """
    # Skip files which are under INPUTS trees or test trees. Normalize the
    # separator so the check also works for native Windows paths.
    path = f.name.replace(os.sep, "/")
    if "INPUTS/" in path or "test/" in path:
        return

    ext = os.path.splitext(path)[1]
    if ext not in (".cpp", ".c", ".h", ".inc", ".def"):
        return

    lines = f.readlines()
    look_for_api_header = ext in (".cpp", ".c")
    found_headers = False
    headers_begin = 0
    headers_end = 0
    api_headers = []
    local_headers = []
    subproject_headers = []
    llvm_headers = []
    system_headers = []
    for i, line in enumerate(lines):
        if not line.strip():
            continue
        header = _include_operand(line)
        if header is not None:
            if not found_headers:
                headers_begin = i
                found_headers = True
            headers_end = i
            if look_for_api_header and header.startswith('"'):
                api_headers.append(header)
                look_for_api_header = False
                continue
            if header in api_headers:
                # Repeat of the main module header, which is already emitted first.
                continue
            if header.startswith(("<", '"gtest/', '"isl/', '"json/')):
                system_headers.append(header)
            elif header.startswith(('"clang/', '"clang-c/', '"polly/')):
                subproject_headers.append(header)
            elif header.startswith(('"llvm/', '"llvm-c/')):
                llvm_headers.append(header)
            else:
                local_headers.append(header)
            continue

        # Only allow comments and #defines (and #ifndef guards) prior to any
        # includes. If either are mixed with includes, the order might be
        # sensitive.
        if found_headers:
            break
        if line.startswith(("//", "#define", "#ifndef")):
            continue
        break
    if not found_headers:
        return

    headers = (
        api_headers
        + sorted(set(local_headers))
        + sorted(set(subproject_headers))
        + sorted(set(llvm_headers))
        + sorted(set(system_headers))
    )
    # Operands were stripped of their line terminator; re-add it so that a
    # final include lacking a newline cannot be glued onto the next line.
    header_lines = ["#include " + h + "\n" for h in headers]
    new_lines = lines[:headers_begin] + header_lines + lines[headers_end + 1 :]
    if new_lines == lines:
        # Already sorted: don't touch the file (keeps mtime and line endings).
        return

    f.seek(0)
    f.truncate()
    f.writelines(new_lines)
93
94
def main(argv=None):
    """Sort the #include block of every file named on the command line.

    ``argv`` is the argument list to parse; when None, argparse falls back to
    ``sys.argv[1:]``, so calling ``main()`` behaves as before.
    """
    import sys

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "files",
        nargs="+",
        type=argparse.FileType("r+"),
        help="the source files to sort includes within",
    )
    args = parser.parse_args(argv)
    for f in args.files:
        try:
            sort_includes(f)
        finally:
            # argparse opens every file but never closes them; close each one
            # once processed so its writes are flushed and its descriptor is
            # released. For "-" argparse hands back sys.stdin itself, which
            # must stay open in case it is listed again.
            if f is not sys.stdin:
                f.close()


if __name__ == "__main__":
    main()
110