xref: /netbsd-src/external/bsd/zstd/dist/contrib/freestanding_lib/freestanding.py (revision 3117ece4fc4a4ca4489ba793710b60b0d26bab6c)
1*3117ece4Schristos#!/usr/bin/env python3
2*3117ece4Schristos# ################################################################
3*3117ece4Schristos# Copyright (c) Meta Platforms, Inc. and affiliates.
4*3117ece4Schristos# All rights reserved.
5*3117ece4Schristos#
6*3117ece4Schristos# This source code is licensed under both the BSD-style license (found in the
7*3117ece4Schristos# LICENSE file in the root directory of this source tree) and the GPLv2 (found
8*3117ece4Schristos# in the COPYING file in the root directory of this source tree).
9*3117ece4Schristos# You may select, at your option, one of the above-listed licenses.
10*3117ece4Schristos# ##########################################################################
11*3117ece4Schristos
12*3117ece4Schristosimport argparse
13*3117ece4Schristosimport contextlib
14*3117ece4Schristosimport os
15*3117ece4Schristosimport re
16*3117ece4Schristosimport shutil
17*3117ece4Schristosimport sys
18*3117ece4Schristosfrom typing import Optional
19*3117ece4Schristos
20*3117ece4Schristos
# Subdirectories of the source lib whose files are copied to the output lib.
INCLUDED_SUBDIRS = ["common", "compress", "decompress"]

# Files (relative to the lib root) that _copy_file never copies.
SKIPPED_FILES = [
    # Replaced by user-supplied copies (see _copy_mem / _copy_zstd_deps).
    "common/mem.h",
    "common/zstd_deps.h",
    # Other files left out of the freestanding library.
    "common/pool.c",
    "common/pool.h",
    "common/threading.c",
    "common/threading.h",
    "common/zstd_trace.h",
    "compress/zstdmt_compress.h",
    "compress/zstdmt_compress.c",
]

# The bundled xxhash sources; skipped when an external xxhash is requested.
XXHASH_FILES = ["common/xxhash." + extension for extension in ("c", "h")]
39*3117ece4Schristos
40*3117ece4Schristos
class FileLines(object):
    """
    Holds the lines of a text file so they can be edited in memory and
    written back to the same path.
    """

    def __init__(self, filename):
        # Remember the path so write() overwrites the file that was read.
        self.filename = filename
        with open(filename, "r") as handle:
            content = handle.readlines()
        self.lines = content

    def write(self):
        """Overwrite the file with the current contents of ``self.lines``."""
        with open(self.filename, "w") as handle:
            handle.write("".join(self.lines))
50*3117ece4Schristos
51*3117ece4Schristos
class PartialPreprocessor(object):
    """
    Looks for simple ifdefs and ifndefs and replaces them.
    Handles && and ||.
    Has fancy logic to handle translating elifs to ifs.
    Only looks for macros in the first part of the expression with no
    parens.
    Does not handle multi-line macros (only looks in first line).

    A condition is only hardwired when its first operand decides it. Lines
    such as ``#if FOO >= 2``, or ``#if defined(A) && B || C`` with A
    undefined, are left untouched.
    """
    def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
        """
        :param defs: (macro, value) pairs treated as defined. A value of None
            means the macro is defined but its value is unknown.
        :param replaces: (macro, value) pairs treated as defined. The first
            conditional hardwired on one of them also emits
            ``#define <macro> <value>`` in its replacement.
        :param undefs: macros treated as undefined.
        """
        MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
        ELIF_GROUP = r"(?P<elif>el)?"
        OP_GROUP = r"(?P<op>&&|\|\|)?"

        self._defs = {macro:value for macro, value in defs}
        self._replaces = {macro:value for macro, value in replaces}
        self._defs.update(self._replaces)
        self._undefs = set(undefs)

        # Generic directive matchers. Note that _if also matches #ifdef and
        # #ifndef; it is used to track nesting depth.
        self._define = re.compile(r"\s*#\s*define")
        self._if = re.compile(r"\s*#\s*if")
        self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
        self._else = re.compile(r"\s*#\s*(?P<else>else)")
        self._endif = re.compile(r"\s*#\s*endif")

        # "#ifdef MACRO" / "#ifndef MACRO" (used with fullmatch).
        self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*")
        # "#[el]if [!]defined(MACRO) [&& | ||]" (prefix match).
        self._if_defined = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
        )
        # "#[el]if defined(MACRO) && [(]MACRO2 <cmp> <digits>[)]" (fullmatch).
        self._if_defined_value = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
            fr"(?P<op>&&)\s*"
            fr"(?P<openp>\()?\s*"
            fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
            fr"(?P<cmp>[=><!]+)\s*"
            fr"(?P<value>[0-9]*)\s*"
            fr"(?P<closep>\))?\s*"
        )
        # "#[el]if MACRO [&& | ||]" (prefix match).
        self._if_true = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+{MACRO_GROUP}\s*{OP_GROUP}"
        )

        self._c_comment = re.compile(r"/\*.*?\*/")
        self._cpp_comment = re.compile(r"//")

    def _log(self, *args, **kwargs):
        print(*args, **kwargs)

    def _strip_comments(self, line):
        """Return ``line`` without /* */ comments and without any // tail."""
        # First strip c-style comments (may include //)
        while True:
            m = self._c_comment.search(line)
            if m is None:
                break
            line = line[:m.start()] + line[m.end():]

        # Then strip cpp-style comments
        m = self._cpp_comment.search(line)
        if m is not None:
            line = line[:m.start()]

        return line

    @staticmethod
    def _tail_defeats_short_circuit(expr):
        """
        Return True if ``expr``, the unparsed remainder of an #if condition
        after its first operand and operator, could change the result of
        short-circuiting on that operand. This is the case when it contains
        ``||`` or ``?`` outside parentheses (both bind looser than ``&&``), or
        when it continues onto the next line.
        """
        if expr.rstrip().endswith('\\'):
            return True
        depth = 0
        for i, c in enumerate(expr):
            if c == '(':
                depth += 1
            elif c == ')':
                depth -= 1
            elif depth <= 0 and (c == '?' or expr.startswith('||', i)):
                return True
        return False

    def _fixup_indentation(self, macro, replace: [str]):
        """
        Strip the common leading indentation from the replacement lines.

        This only happens when every line is a preprocessor line: either
        all of them start with '#' (indentation after the '#' is reduced),
        or each line's first non-space character is '#'. Otherwise, and for
        a single line that is not a #define, ``replace`` is returned as-is.
        """
        if len(replace) == 0:
            return replace
        if len(replace) == 1 and self._define.match(replace[0]) is None:
            # If there is only one line, only replace defines
            return replace


        all_pound = True
        for line in replace:
            if not line.startswith('#'):
                all_pound = False
        if all_pound:
            replace = [line[1:] for line in replace]

        min_spaces = len(replace[0])
        for line in replace:
            spaces = 0
            for i, c in enumerate(line):
                if c != ' ':
                    # Non-preprocessor line ==> skip the fixup
                    if not all_pound and c != '#':
                        return replace
                    spaces = i
                    break
            min_spaces = min(min_spaces, spaces)

        replace = [line[min_spaces:] for line in replace]

        if all_pound:
            replace = ["#" + line for line in replace]

        return replace

    def _take_else_branch(self, idx, replace):
        """
        Append the body of an #else branch, starting at ``idx``, to
        ``replace``. Return the index just past the #endif that closes it.
        Nested #if/#endif pairs inside the body are kept intact.

        :raises RuntimeError: if the closing #endif is missing.
        """
        depth = 0
        while idx < len(self._inlines):
            line = self._inlines[idx]
            if self._if.match(line):
                depth += 1
            elif self._endif.match(line):
                if depth == 0:
                    return idx + 1
                depth -= 1
            replace.append(line)
            idx += 1
        raise RuntimeError("Unterminated if block!")

    def _handle_if_block(self, macro, idx, is_true, prepend):
        """
        Remove the #if or #elif block starting on this line.

        :param macro: macro being hardwired (only used for logging).
        :param idx: index in self._inlines of the #if / #elif line.
        :param is_true: whether that line's condition is known to be true.
        :param prepend: lines to emit before the kept branch.
        :returns: (index of the first line after the rewritten region,
            replacement lines for that region).
        :raises RuntimeError: if no matching #endif is found.
        """
        REMOVE_ONE = 0   # condition false: drop this branch
        KEEP_ONE = 1     # condition true: keep this branch
        REMOVE_REST = 2  # branch kept: drop the remaining branches

        if is_true:
            state = KEEP_ONE
        else:
            state = REMOVE_ONE

        line = self._inlines[idx]
        is_if = self._if.match(line) is not None
        elif_match = self._elif.match(line)
        assert is_if or elif_match is not None
        depth = 0

        start_idx = idx

        idx += 1
        replace = list(prepend)
        if is_true and not is_if:
            # A true #elif ends the chain: its body becomes the #else branch
            # of the preceding #if, and the chain's #endif must be kept.
            b = elif_match.start('elif')
            replace.insert(0, line[:b] + "else\n")
        finished = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            # Nested if statement
            if self._if.match(line):
                depth += 1
                idx += 1
                continue
            # We're inside a nested statement
            if depth > 0:
                if self._endif.match(line):
                    depth -= 1
                idx += 1
                continue

            # We're at the original depth

            # Looking only for an endif.
            # We've found a true statement, but haven't
            # completely elided the if block, so we just
            # remove the remainder.
            if state == REMOVE_REST:
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because we took the first if
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue

            if state == KEEP_ONE:
                if self._endif.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    if is_if:
                        # The whole chain is gone, so its endif goes too.
                        # A true #elif keeps it to close the new #else.
                        idx += 1
                    finished = True
                    break
                if self._elif.match(line) or self._else.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    state = REMOVE_REST
                idx += 1
                continue

            if state == REMOVE_ONE:
                m = self._elif.match(line)
                if m is not None:
                    if is_if:
                        # The next #elif becomes the chain's new #if.
                        idx += 1
                        b = m.start('elif')
                        e = m.end('elif')
                        assert e - b == 2
                        replace.append(line[:b] + line[e:])
                    finished = True
                    break
                m = self._else.match(line)
                if m is not None:
                    if is_if:
                        # Keep the #else body unconditionally; drop its endif.
                        idx = self._take_else_branch(idx + 1, replace)
                    finished = True
                    break
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because no other elifs
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue
        if not finished:
            raise RuntimeError("Unterminated if block!")

        replace = self._fixup_indentation(macro, replace)

        self._log(f"\tHardwiring {macro}")
        if start_idx > 0:
            self._log(f"\t\t  {self._inlines[start_idx - 1][:-1]}")
        for x in range(start_idx, idx):
            self._log(f"\t\t- {self._inlines[x][:-1]}")
        for line in replace:
            self._log(f"\t\t+ {line[:-1]}")
        if idx < len(self._inlines):
            self._log(f"\t\t  {self._inlines[idx][:-1]}")

        return idx, replace

    def _preprocess_once(self):
        """
        Make one pass over self._inlines and hardwire every conditional whose
        leading macro is in defs/undefs.

        :returns: (changed, outlines). ``changed`` says whether any line was
            rewritten in this pass; ``outlines`` holds the resulting lines.
        """
        outlines = []
        idx = 0
        changed = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            sline = self._strip_comments(line)
            m = self._ifdef.fullmatch(sline)
            if_true = False
            if m is None:
                m = self._if_defined_value.fullmatch(sline)
            if m is None:
                m = self._if_defined.match(sline)
            if m is None:
                m = self._if_true.match(sline)
                if_true = (m is not None)
            if m is None:
                outlines.append(line)
                idx += 1
                continue

            groups = m.groupdict()
            macro = groups['macro']
            op = groups.get('op')

            if not (macro in self._defs or macro in self._undefs):
                outlines.append(line)
                idx += 1
                continue

            # _if_defined and _if_true are prefix matches. Whatever follows
            # the first operand (and its && / ||) has not been parsed.
            rest = sline[m.end():]
            if op is None and rest.strip():
                # e.g. "#if FOO >= 2": the macro alone is not the condition.
                outlines.append(line)
                idx += 1
                continue
            # Whether "A && ..." being false or "A || ..." being true could be
            # overridden by the rest of the line (e.g. "A && B || C").
            tail_unsafe = self._tail_defeats_short_circuit(rest)

            defined = macro in self._defs

            # Needed variables set:
            # resolved: Is the statement fully resolved?
            # is_true: If resolved, is the statement true?
            ifdef = False
            if if_true:
                if not defined:
                    outlines.append(line)
                    idx += 1
                    continue

                defined_value = self._defs[macro]
                is_int = True
                try:
                    defined_value = int(defined_value)
                except TypeError:
                    is_int = False
                except ValueError:
                    is_int = False

                resolved = is_int
                is_true = (defined_value != 0)

                if resolved and op is not None:
                    if op == '&&':
                        resolved = not is_true
                    else:
                        assert op == '||'
                        resolved = is_true
                    # Short-circuiting only decides the line when nothing of
                    # lower precedence follows.
                    resolved = resolved and not tail_unsafe

            else:
                ifdef = groups.get('not') is None

                macro2 = groups.get('macro2')
                cmp = groups.get('cmp')
                value = groups.get('value')
                openp = groups.get('openp')
                closep = groups.get('closep')

                is_true = (ifdef == defined)
                resolved = True
                if op is not None:
                    if op == '&&':
                        resolved = not is_true
                    else:
                        assert op == '||'
                        resolved = is_true
                    if resolved and tail_unsafe:
                        # Neither short-circuiting nor dropping the first
                        # operand is valid, so leave the line as it is.
                        outlines.append(line)
                        idx += 1
                        continue

                if macro2 is not None and not resolved:
                    assert ifdef and defined and op == '&&' and cmp is not None
                    # If the statement is true, but we have a single value check, then
                    # check the value.
                    defined_value = self._defs[macro]
                    are_ints = True
                    try:
                        defined_value = int(defined_value)
                        value = int(value)
                    except TypeError:
                        are_ints = False
                    except ValueError:
                        are_ints = False
                    if (
                            macro == macro2 and
                            ((openp is None) == (closep is None)) and
                            are_ints
                    ):
                        resolved = True
                        if cmp == '<':
                            is_true = defined_value < value
                        elif cmp == '<=':
                            is_true = defined_value <= value
                        elif cmp == '==':
                            is_true = defined_value == value
                        elif cmp == '!=':
                            is_true = defined_value != value
                        elif cmp == '>=':
                            is_true = defined_value >= value
                        elif cmp == '>':
                            is_true = defined_value > value
                        else:
                            resolved = False

                if op is not None and not resolved:
                    # Remove the first op in the line + spaces
                    if op == '&&':
                        opre = op
                    else:
                        assert op == '||'
                        opre = r'\|\|'
                    needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)")
                    match = needle.match(line)
                    assert match is not None
                    newline = line[:match.end('if')] + line[match.end('op'):]

                    self._log(f"\tHardwiring partially resolved {macro}")
                    self._log(f"\t\t- {line[:-1]}")
                    self._log(f"\t\t+ {newline[:-1]}")

                    outlines.append(newline)
                    idx += 1
                    # The shortened line may now be resolvable, so another
                    # pass is needed.
                    changed = True
                    continue

            # Skip any statements we cannot fully compute
            if not resolved:
                outlines.append(line)
                idx += 1
                continue

            prepend = []
            if macro in self._replaces:
                assert not ifdef
                assert op is None
                value = self._replaces.pop(macro)
                prepend = [f"#define {macro} {value}\n"]

            idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
            outlines += replace
            changed = True

        return changed, outlines

    def preprocess(self, filename):
        """
        Hardwire the configured macros in ``filename``, rewriting it in place.

        Passes repeat until one makes no change. A rewrite can expose new
        resolvable lines, for example an #elif turned into an #if or a
        partially resolved condition.
        """
        with open(filename, 'r') as f:
            self._inlines = f.readlines()
        changed = True
        while changed:
            changed, outlines = self._preprocess_once()
            self._inlines = outlines

        with open(filename, 'w') as f:
            f.write(''.join(self._inlines))
426*3117ece4Schristos
427*3117ece4Schristos
428*3117ece4Schristosclass Freestanding(object):
429*3117ece4Schristos    def __init__(
430*3117ece4Schristos            self, zstd_deps: str, mem: str, source_lib: str, output_lib: str,
431*3117ece4Schristos            external_xxhash: bool, xxh64_state: Optional[str],
432*3117ece4Schristos            xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
433*3117ece4Schristos            defs: [(str, Optional[str])], replaces: [(str, str)],
434*3117ece4Schristos            undefs: [str], excludes: [str], seds: [str], spdx: bool,
435*3117ece4Schristos    ):
436*3117ece4Schristos        self._zstd_deps = zstd_deps
437*3117ece4Schristos        self._mem = mem
438*3117ece4Schristos        self._src_lib = source_lib
439*3117ece4Schristos        self._dst_lib = output_lib
440*3117ece4Schristos        self._external_xxhash = external_xxhash
441*3117ece4Schristos        self._xxh64_state = xxh64_state
442*3117ece4Schristos        self._xxh64_prefix = xxh64_prefix
443*3117ece4Schristos        self._rewritten_includes = rewritten_includes
444*3117ece4Schristos        self._defs = defs
445*3117ece4Schristos        self._replaces = replaces
446*3117ece4Schristos        self._undefs = undefs
447*3117ece4Schristos        self._excludes = excludes
448*3117ece4Schristos        self._seds = seds
449*3117ece4Schristos        self._spdx = spdx
450*3117ece4Schristos
451*3117ece4Schristos    def _dst_lib_file_paths(self):
452*3117ece4Schristos        """
453*3117ece4Schristos        Yields all the file paths in the dst_lib.
454*3117ece4Schristos        """
455*3117ece4Schristos        for root, dirname, filenames in os.walk(self._dst_lib):
456*3117ece4Schristos            for filename in filenames:
457*3117ece4Schristos                filepath = os.path.join(root, filename)
458*3117ece4Schristos                yield filepath
459*3117ece4Schristos
460*3117ece4Schristos    def _log(self, *args, **kwargs):
461*3117ece4Schristos        print(*args, **kwargs)
462*3117ece4Schristos
463*3117ece4Schristos    def _copy_file(self, lib_path):
464*3117ece4Schristos        suffixes = [".c", ".h", ".S"]
465*3117ece4Schristos        if not any((lib_path.endswith(suffix) for suffix in suffixes)):
466*3117ece4Schristos            return
467*3117ece4Schristos        if lib_path in SKIPPED_FILES:
468*3117ece4Schristos            self._log(f"\tSkipping file: {lib_path}")
469*3117ece4Schristos            return
470*3117ece4Schristos        if self._external_xxhash and lib_path in XXHASH_FILES:
471*3117ece4Schristos            self._log(f"\tSkipping xxhash file: {lib_path}")
472*3117ece4Schristos            return
473*3117ece4Schristos
474*3117ece4Schristos        src_path = os.path.join(self._src_lib, lib_path)
475*3117ece4Schristos        dst_path = os.path.join(self._dst_lib, lib_path)
476*3117ece4Schristos        self._log(f"\tCopying: {src_path} -> {dst_path}")
477*3117ece4Schristos        shutil.copyfile(src_path, dst_path)
478*3117ece4Schristos
479*3117ece4Schristos    def _copy_source_lib(self):
480*3117ece4Schristos        self._log("Copying source library into output library")
481*3117ece4Schristos
482*3117ece4Schristos        assert os.path.exists(self._src_lib)
483*3117ece4Schristos        os.makedirs(self._dst_lib, exist_ok=True)
484*3117ece4Schristos        self._copy_file("zstd.h")
485*3117ece4Schristos        self._copy_file("zstd_errors.h")
486*3117ece4Schristos        for subdir in INCLUDED_SUBDIRS:
487*3117ece4Schristos            src_dir = os.path.join(self._src_lib, subdir)
488*3117ece4Schristos            dst_dir = os.path.join(self._dst_lib, subdir)
489*3117ece4Schristos
490*3117ece4Schristos            assert os.path.exists(src_dir)
491*3117ece4Schristos            os.makedirs(dst_dir, exist_ok=True)
492*3117ece4Schristos
493*3117ece4Schristos            for filename in os.listdir(src_dir):
494*3117ece4Schristos                lib_path = os.path.join(subdir, filename)
495*3117ece4Schristos                self._copy_file(lib_path)
496*3117ece4Schristos
497*3117ece4Schristos    def _copy_zstd_deps(self):
498*3117ece4Schristos        dst_zstd_deps = os.path.join(self._dst_lib, "common", "zstd_deps.h")
499*3117ece4Schristos        self._log(f"Copying zstd_deps: {self._zstd_deps} -> {dst_zstd_deps}")
500*3117ece4Schristos        shutil.copyfile(self._zstd_deps, dst_zstd_deps)
501*3117ece4Schristos
502*3117ece4Schristos    def _copy_mem(self):
503*3117ece4Schristos        dst_mem = os.path.join(self._dst_lib, "common", "mem.h")
504*3117ece4Schristos        self._log(f"Copying mem: {self._mem} -> {dst_mem}")
505*3117ece4Schristos        shutil.copyfile(self._mem, dst_mem)
506*3117ece4Schristos
507*3117ece4Schristos    def _hardwire_preprocessor(self, name: str, value: Optional[str] = None, undef=False):
508*3117ece4Schristos        """
509*3117ece4Schristos        If value=None then hardwire that it is defined, but not what the value is.
510*3117ece4Schristos        If undef=True then value must be None.
511*3117ece4Schristos        If value='' then the macro is defined to '' exactly.
512*3117ece4Schristos        """
513*3117ece4Schristos        assert not (undef and value is not None)
514*3117ece4Schristos        for filepath in self._dst_lib_file_paths():
515*3117ece4Schristos            file = FileLines(filepath)
516*3117ece4Schristos
517*3117ece4Schristos    def _hardwire_defines(self):
518*3117ece4Schristos        self._log("Hardwiring macros")
519*3117ece4Schristos        partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs)
520*3117ece4Schristos        for filepath in self._dst_lib_file_paths():
521*3117ece4Schristos            partial_preprocessor.preprocess(filepath)
522*3117ece4Schristos
523*3117ece4Schristos    def _remove_excludes(self):
524*3117ece4Schristos        self._log("Removing excluded sections")
525*3117ece4Schristos        for exclude in self._excludes:
526*3117ece4Schristos            self._log(f"\tRemoving excluded sections for: {exclude}")
527*3117ece4Schristos            begin_re = re.compile(f"BEGIN {exclude}")
528*3117ece4Schristos            end_re = re.compile(f"END {exclude}")
529*3117ece4Schristos            for filepath in self._dst_lib_file_paths():
530*3117ece4Schristos                file = FileLines(filepath)
531*3117ece4Schristos                outlines = []
532*3117ece4Schristos                skipped = []
533*3117ece4Schristos                emit = True
534*3117ece4Schristos                for line in file.lines:
535*3117ece4Schristos                    if emit and begin_re.search(line) is not None:
536*3117ece4Schristos                        assert end_re.search(line) is None
537*3117ece4Schristos                        emit = False
538*3117ece4Schristos                    if emit:
539*3117ece4Schristos                        outlines.append(line)
540*3117ece4Schristos                    else:
541*3117ece4Schristos                        skipped.append(line)
542*3117ece4Schristos                        if end_re.search(line) is not None:
543*3117ece4Schristos                            assert begin_re.search(line) is None
544*3117ece4Schristos                            self._log(f"\t\tRemoving excluded section: {exclude}")
545*3117ece4Schristos                            for s in skipped:
546*3117ece4Schristos                                self._log(f"\t\t\t- {s}")
547*3117ece4Schristos                            emit = True
548*3117ece4Schristos                            skipped = []
549*3117ece4Schristos                if not emit:
550*3117ece4Schristos                    raise RuntimeError("Excluded section unfinished!")
551*3117ece4Schristos                file.lines = outlines
552*3117ece4Schristos                file.write()
553*3117ece4Schristos
554*3117ece4Schristos    def _rewrite_include(self, original, rewritten):
555*3117ece4Schristos        self._log(f"\tRewriting include: {original} -> {rewritten}")
556*3117ece4Schristos        regex = re.compile(f"\\s*#\\s*include\\s*(?P<include>{original})")
557*3117ece4Schristos        for filepath in self._dst_lib_file_paths():
558*3117ece4Schristos            file = FileLines(filepath)
559*3117ece4Schristos            for i, line in enumerate(file.lines):
560*3117ece4Schristos                match = regex.match(line)
561*3117ece4Schristos                if match is None:
562*3117ece4Schristos                    continue
563*3117ece4Schristos                s = match.start('include')
564*3117ece4Schristos                e = match.end('include')
565*3117ece4Schristos                file.lines[i] = line[:s] + rewritten + line[e:]
566*3117ece4Schristos            file.write()
567*3117ece4Schristos
568*3117ece4Schristos    def _rewrite_includes(self):
569*3117ece4Schristos        self._log("Rewriting includes")
570*3117ece4Schristos        for original, rewritten in self._rewritten_includes:
571*3117ece4Schristos            self._rewrite_include(original, rewritten)
572*3117ece4Schristos
573*3117ece4Schristos    def _replace_xxh64_prefix(self):
574*3117ece4Schristos        if self._xxh64_prefix is None:
575*3117ece4Schristos            return
576*3117ece4Schristos        self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
577*3117ece4Schristos        replacements = []
578*3117ece4Schristos        if self._xxh64_state is not None:
579*3117ece4Schristos            replacements.append(
580*3117ece4Schristos                (re.compile(r"([^\w]|^)(?P<orig>XXH64_state_t)([^\w]|$)"), self._xxh64_state)
581*3117ece4Schristos            )
582*3117ece4Schristos        if self._xxh64_prefix is not None:
583*3117ece4Schristos            replacements.append(
584*3117ece4Schristos                (re.compile(r"([^\w]|^)(?P<orig>XXH64)[\(_]"), self._xxh64_prefix)
585*3117ece4Schristos            )
586*3117ece4Schristos        for filepath in self._dst_lib_file_paths():
587*3117ece4Schristos            file = FileLines(filepath)
588*3117ece4Schristos            for i, line in enumerate(file.lines):
589*3117ece4Schristos                modified = False
590*3117ece4Schristos                for regex, replacement in replacements:
591*3117ece4Schristos                    match = regex.search(line)
592*3117ece4Schristos                    while match is not None:
593*3117ece4Schristos                        modified = True
594*3117ece4Schristos                        b = match.start('orig')
595*3117ece4Schristos                        e = match.end('orig')
596*3117ece4Schristos                        line = line[:b] + replacement + line[e:]
597*3117ece4Schristos                        match = regex.search(line)
598*3117ece4Schristos                if modified:
599*3117ece4Schristos                    self._log(f"\t- {file.lines[i][:-1]}")
600*3117ece4Schristos                    self._log(f"\t+ {line[:-1]}")
601*3117ece4Schristos                file.lines[i] = line
602*3117ece4Schristos            file.write()
603*3117ece4Schristos
604*3117ece4Schristos    def _parse_sed(self, sed):
605*3117ece4Schristos        assert sed[0] == 's'
606*3117ece4Schristos        delim = sed[1]
607*3117ece4Schristos        match = re.fullmatch(f's{delim}(.+){delim}(.*){delim}(.*)', sed)
608*3117ece4Schristos        assert match is not None
609*3117ece4Schristos        regex = re.compile(match.group(1))
610*3117ece4Schristos        format_str = match.group(2)
611*3117ece4Schristos        is_global = match.group(3) == 'g'
612*3117ece4Schristos        return regex, format_str, is_global
613*3117ece4Schristos
614*3117ece4Schristos    def _process_sed(self, sed):
615*3117ece4Schristos        self._log(f"Processing sed: {sed}")
616*3117ece4Schristos        regex, format_str, is_global = self._parse_sed(sed)
617*3117ece4Schristos
618*3117ece4Schristos        for filepath in self._dst_lib_file_paths():
619*3117ece4Schristos            file = FileLines(filepath)
620*3117ece4Schristos            for i, line in enumerate(file.lines):
621*3117ece4Schristos                modified = False
622*3117ece4Schristos                while True:
623*3117ece4Schristos                    match = regex.search(line)
624*3117ece4Schristos                    if match is None:
625*3117ece4Schristos                        break
626*3117ece4Schristos                    replacement = format_str.format(match.groups(''), match.groupdict(''))
627*3117ece4Schristos                    b = match.start()
628*3117ece4Schristos                    e = match.end()
629*3117ece4Schristos                    line = line[:b] + replacement + line[e:]
630*3117ece4Schristos                    modified = True
631*3117ece4Schristos                    if not is_global:
632*3117ece4Schristos                        break
633*3117ece4Schristos                if modified:
634*3117ece4Schristos                    self._log(f"\t- {file.lines[i][:-1]}")
635*3117ece4Schristos                    self._log(f"\t+ {line[:-1]}")
636*3117ece4Schristos                file.lines[i] = line
637*3117ece4Schristos            file.write()
638*3117ece4Schristos
639*3117ece4Schristos    def _process_seds(self):
640*3117ece4Schristos        self._log("Processing seds")
641*3117ece4Schristos        for sed in self._seds:
642*3117ece4Schristos            self._process_sed(sed)
643*3117ece4Schristos
644*3117ece4Schristos    def _process_spdx(self):
645*3117ece4Schristos        if not self._spdx:
646*3117ece4Schristos            return
647*3117ece4Schristos        self._log("Processing spdx")
648*3117ece4Schristos        SPDX_C = "// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause\n"
649*3117ece4Schristos        SPDX_H_S = "/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */\n"
650*3117ece4Schristos        for filepath in self._dst_lib_file_paths():
651*3117ece4Schristos            file = FileLines(filepath)
652*3117ece4Schristos            if file.lines[0] == SPDX_C or file.lines[0] == SPDX_H_S:
653*3117ece4Schristos                continue
654*3117ece4Schristos            for line in file.lines:
655*3117ece4Schristos                if "SPDX-License-Identifier" in line:
656*3117ece4Schristos                    raise RuntimeError(f"Unexpected SPDX license identifier: {file.filename} {repr(line)}")
657*3117ece4Schristos            if file.filename.endswith(".c"):
658*3117ece4Schristos                file.lines.insert(0, SPDX_C)
659*3117ece4Schristos            elif file.filename.endswith(".h") or file.filename.endswith(".S"):
660*3117ece4Schristos                file.lines.insert(0, SPDX_H_S)
661*3117ece4Schristos            else:
662*3117ece4Schristos                raise RuntimeError(f"Unexpected file extension: {file.filename}")
663*3117ece4Schristos            file.write()
664*3117ece4Schristos
665*3117ece4Schristos
666*3117ece4Schristos
667*3117ece4Schristos    def go(self):
668*3117ece4Schristos        self._copy_source_lib()
669*3117ece4Schristos        self._copy_zstd_deps()
670*3117ece4Schristos        self._copy_mem()
671*3117ece4Schristos        self._hardwire_defines()
672*3117ece4Schristos        self._remove_excludes()
673*3117ece4Schristos        self._rewrite_includes()
674*3117ece4Schristos        self._replace_xxh64_prefix()
675*3117ece4Schristos        self._process_seds()
676*3117ece4Schristos        self._process_spdx()
677*3117ece4Schristos
678*3117ece4Schristos
def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
    """Parse ``NAME`` or ``NAME=VALUE`` macro definitions.

    Only the first ``=`` separates the name from the value, so a value may
    itself contain ``=`` (e.g. ``CHECK=(a==b)``).

    :param defines: Raw definition strings, e.g. from repeated ``-D`` options.
    :returns: ``(name, value)`` pairs in input order. ``value`` is ``None``
        when no ``=`` was present, and ``''`` for ``NAME=``.
    :raises RuntimeError: if a definition has an empty name.
    """
    output = []
    for define in defines:
        name, sep, value = define.partition('=')
        if not name:
            raise RuntimeError(f"Bad define: {define}")
        output.append((name, value if sep else None))
    return output
690*3117ece4Schristos
691*3117ece4Schristos
def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
    """Parse mandatory ``KEY=VALUE`` pairs.

    This is used both for ``--rewrite-include`` (``REGEX=NEW``) and for
    ``--replace`` (``MACRO=DEFINITION``). Only the first ``=`` separates key
    from value, so the value may contain ``=``. An empty value is allowed
    (e.g. ``ZSTDLIB_VISIBILITY=``).

    :param rewritten_includes: Raw ``KEY=VALUE`` strings.
    :returns: ``(key, value)`` pairs in input order.
    :raises RuntimeError: if an entry has no ``=`` or an empty key.
    """
    output = []
    for rewritten_include in rewritten_includes:
        key, sep, value = rewritten_include.partition('=')
        if not sep or not key:
            raise RuntimeError(f"Bad rewritten include: {rewritten_include}")
        output.append((key, value))
    return output
701*3117ece4Schristos
702*3117ece4Schristos
703*3117ece4Schristos
def main(name, args):
    """Parse command-line *args* and generate the freestanding zstd library.

    :param name: Program name used by argparse in usage/help output.
    :param args: Command-line arguments, excluding the program name.
    :raises RuntimeError: on malformed options or when a macro is requested
        in conflicting ways (defined/undefined/replaced), or when an xxh64
        override is given without ``--xxhash``.
    """
    parser = argparse.ArgumentParser(prog=name)
    parser.add_argument("--zstd-deps", default="zstd_deps.h", help="Zstd dependencies file")
    parser.add_argument("--mem", default="mem.h", help="Memory module")
    parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
    parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
    parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
    parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
    parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
    parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')")
    parser.add_argument("--sed", default=[], dest="seds", action="append", help="Apply a sed replacement. Format: `s/REGEX/FORMAT/[g]`. REGEX is a Python regex. FORMAT is a Python format string formatted by the regex dict.")
    parser.add_argument("--spdx", action="store_true", help="Add SPDX License Identifiers")
    parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
    parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed multiple times)")
    parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
    parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
    args = parser.parse_args(args)

    # Always remove threading
    if "ZSTD_MULTITHREAD" not in args.undefs:
        args.undefs.append("ZSTD_MULTITHREAD")

    args.defs = parse_optional_pair(args.defs)
    # `macro` (not `name`) so the program-name parameter is not shadowed.
    for macro, _ in args.defs:
        if macro in args.undefs:
            raise RuntimeError(f"{macro} is both defined and undefined!")

    # Always set tracing to 0
    if "ZSTD_NO_TRACE" not in (arg[0] for arg in args.defs):
        args.defs.append(("ZSTD_NO_TRACE", None))
        args.defs.append(("ZSTD_TRACE", "0"))

    # args.defs now holds (name, value) tuples. Compare against the names only:
    # `macro in args.defs` would test a str against tuples and never match,
    # so a replace/define conflict would go undetected.
    defined_macros = {macro for macro, _ in args.defs}
    args.replaces = parse_pair(args.replaces)
    for macro, _ in args.replaces:
        if macro in args.undefs or macro in defined_macros:
            raise RuntimeError(f"{macro} is both replaced and (un)defined!")

    args.rewritten_includes = parse_pair(args.rewritten_includes)

    external_xxhash = False
    if args.xxhash is not None:
        external_xxhash = True
        args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))

    if args.xxh64_prefix is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")

    if args.xxh64_state is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-state may only be used with --xxhash provided")

    Freestanding(
        args.zstd_deps,
        args.mem,
        args.source_lib,
        args.output_lib,
        external_xxhash,
        args.xxh64_state,
        args.xxh64_prefix,
        args.rewritten_includes,
        args.defs,
        args.replaces,
        args.undefs,
        args.excludes,
        args.seds,
        args.spdx,
    ).go()
772*3117ece4Schristos
if __name__ == "__main__":
    # Script entry point: pass the invoking program name and the remaining
    # command-line arguments through to main().
    main(name=sys.argv[0], args=sys.argv[1:])
775