1*3117ece4Schristos#!/usr/bin/env python3 2*3117ece4Schristos# ################################################################ 3*3117ece4Schristos# Copyright (c) Meta Platforms, Inc. and affiliates. 4*3117ece4Schristos# All rights reserved. 5*3117ece4Schristos# 6*3117ece4Schristos# This source code is licensed under both the BSD-style license (found in the 7*3117ece4Schristos# LICENSE file in the root directory of this source tree) and the GPLv2 (found 8*3117ece4Schristos# in the COPYING file in the root directory of this source tree). 9*3117ece4Schristos# You may select, at your option, one of the above-listed licenses. 10*3117ece4Schristos# ########################################################################## 11*3117ece4Schristos 12*3117ece4Schristosimport argparse 13*3117ece4Schristosimport contextlib 14*3117ece4Schristosimport os 15*3117ece4Schristosimport re 16*3117ece4Schristosimport shutil 17*3117ece4Schristosimport sys 18*3117ece4Schristosfrom typing import Optional 19*3117ece4Schristos 20*3117ece4Schristos 21*3117ece4SchristosINCLUDED_SUBDIRS = ["common", "compress", "decompress"] 22*3117ece4Schristos 23*3117ece4SchristosSKIPPED_FILES = [ 24*3117ece4Schristos "common/mem.h", 25*3117ece4Schristos "common/zstd_deps.h", 26*3117ece4Schristos "common/pool.c", 27*3117ece4Schristos "common/pool.h", 28*3117ece4Schristos "common/threading.c", 29*3117ece4Schristos "common/threading.h", 30*3117ece4Schristos "common/zstd_trace.h", 31*3117ece4Schristos "compress/zstdmt_compress.h", 32*3117ece4Schristos "compress/zstdmt_compress.c", 33*3117ece4Schristos] 34*3117ece4Schristos 35*3117ece4SchristosXXHASH_FILES = [ 36*3117ece4Schristos "common/xxhash.c", 37*3117ece4Schristos "common/xxhash.h", 38*3117ece4Schristos] 39*3117ece4Schristos 40*3117ece4Schristos 41*3117ece4Schristosclass FileLines(object): 42*3117ece4Schristos def __init__(self, filename): 43*3117ece4Schristos self.filename = filename 44*3117ece4Schristos with open(self.filename, "r") as f: 45*3117ece4Schristos self.lines = f.readlines() 46*3117ece4Schristos 47*3117ece4Schristos def write(self): 48*3117ece4Schristos with open(self.filename, "w") as f: 49*3117ece4Schristos f.write("".join(self.lines)) 50*3117ece4Schristos 51*3117ece4Schristos 52*3117ece4Schristosclass PartialPreprocessor(object): 53*3117ece4Schristos """ 54*3117ece4Schristos Looks for simple ifdefs and ifndefs and replaces them. 55*3117ece4Schristos Handles && and ||. 56*3117ece4Schristos Has fancy logic to handle translating elifs to ifs. 57*3117ece4Schristos Only looks for macros in the first part of the expression with no 58*3117ece4Schristos parens. 59*3117ece4Schristos Does not handle multi-line macros (only looks in first line). 60*3117ece4Schristos """ 61*3117ece4Schristos def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]): 62*3117ece4Schristos MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)" 63*3117ece4Schristos ELIF_GROUP = r"(?P<elif>el)?" 64*3117ece4Schristos OP_GROUP = r"(?P<op>&&|\|\|)?" 65*3117ece4Schristos 66*3117ece4Schristos self._defs = {macro:value for macro, value in defs} 67*3117ece4Schristos self._replaces = {macro:value for macro, value in replaces} 68*3117ece4Schristos self._defs.update(self._replaces) 69*3117ece4Schristos self._undefs = set(undefs) 70*3117ece4Schristos 71*3117ece4Schristos self._define = re.compile(r"\s*#\s*define") 72*3117ece4Schristos self._if = re.compile(r"\s*#\s*if") 73*3117ece4Schristos self._elif = re.compile(r"\s*#\s*(?P<elif>el)if") 74*3117ece4Schristos self._else = re.compile(r"\s*#\s*(?P<else>else)") 75*3117ece4Schristos self._endif = re.compile(r"\s*#\s*endif") 76*3117ece4Schristos 77*3117ece4Schristos self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*") 78*3117ece4Schristos self._if_defined = re.compile( 79*3117ece4Schristos fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}" 80*3117ece4Schristos ) 81*3117ece4Schristos self._if_defined_value = re.compile( 82*3117ece4Schristos fr"\s*#\s*{ELIF_GROUP}if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*" 83*3117ece4Schristos fr"(?P<op>&&)\s*" 84*3117ece4Schristos fr"(?P<openp>\()?\s*" 85*3117ece4Schristos fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*" 86*3117ece4Schristos fr"(?P<cmp>[=><!]+)\s*" 87*3117ece4Schristos fr"(?P<value>[0-9]*)\s*" 88*3117ece4Schristos fr"(?P<closep>\))?\s*" 89*3117ece4Schristos ) 90*3117ece4Schristos self._if_true = re.compile( 91*3117ece4Schristos fr"\s*#\s*{ELIF_GROUP}if\s+{MACRO_GROUP}\s*{OP_GROUP}" 92*3117ece4Schristos ) 93*3117ece4Schristos 94*3117ece4Schristos self._c_comment = re.compile(r"/\*.*?\*/") 95*3117ece4Schristos self._cpp_comment = re.compile(r"//") 96*3117ece4Schristos 97*3117ece4Schristos def _log(self, *args, **kwargs): 98*3117ece4Schristos print(*args, **kwargs) 99*3117ece4Schristos 100*3117ece4Schristos def _strip_comments(self, line): 101*3117ece4Schristos # First strip c-style comments (may include //) 102*3117ece4Schristos while True: 103*3117ece4Schristos m = self._c_comment.search(line) 104*3117ece4Schristos if m is None: 105*3117ece4Schristos break 106*3117ece4Schristos line = line[:m.start()] + line[m.end():] 107*3117ece4Schristos 108*3117ece4Schristos # Then strip cpp-style comments 109*3117ece4Schristos m = self._cpp_comment.search(line) 110*3117ece4Schristos if m is not None: 111*3117ece4Schristos line = line[:m.start()] 112*3117ece4Schristos 113*3117ece4Schristos return line 114*3117ece4Schristos 115*3117ece4Schristos def _fixup_indentation(self, macro, replace: [str]): 116*3117ece4Schristos if len(replace) == 0: 117*3117ece4Schristos return replace 118*3117ece4Schristos if len(replace) == 1 and self._define.match(replace[0]) is None: 119*3117ece4Schristos # If there is only one line, only replace defines 120*3117ece4Schristos return replace 121*3117ece4Schristos 122*3117ece4Schristos 123*3117ece4Schristos all_pound = True 124*3117ece4Schristos for line in replace: 125*3117ece4Schristos if not line.startswith('#'): 126*3117ece4Schristos all_pound = False 127*3117ece4Schristos if all_pound: 128*3117ece4Schristos replace = [line[1:] for line in replace] 129*3117ece4Schristos 130*3117ece4Schristos min_spaces = len(replace[0]) 131*3117ece4Schristos for line in replace: 132*3117ece4Schristos spaces = 0 133*3117ece4Schristos for i, c in enumerate(line): 134*3117ece4Schristos if c != ' ': 135*3117ece4Schristos # Non-preprocessor line ==> skip the fixup 136*3117ece4Schristos if not all_pound and c != '#': 137*3117ece4Schristos return replace 138*3117ece4Schristos spaces = i 139*3117ece4Schristos break 140*3117ece4Schristos min_spaces = min(min_spaces, spaces) 141*3117ece4Schristos 142*3117ece4Schristos replace = [line[min_spaces:] for line in replace] 143*3117ece4Schristos 144*3117ece4Schristos if all_pound: 145*3117ece4Schristos replace = ["#" + line for line in replace] 146*3117ece4Schristos 147*3117ece4Schristos return replace 148*3117ece4Schristos 149*3117ece4Schristos def _handle_if_block(self, macro, idx, is_true, prepend): 150*3117ece4Schristos """ 151*3117ece4Schristos Remove the #if or #elif block starting on this line. 152*3117ece4Schristos """ 153*3117ece4Schristos REMOVE_ONE = 0 154*3117ece4Schristos KEEP_ONE = 1 155*3117ece4Schristos REMOVE_REST = 2 156*3117ece4Schristos 157*3117ece4Schristos if is_true: 158*3117ece4Schristos state = KEEP_ONE 159*3117ece4Schristos else: 160*3117ece4Schristos state = REMOVE_ONE 161*3117ece4Schristos 162*3117ece4Schristos line = self._inlines[idx] 163*3117ece4Schristos is_if = self._if.match(line) is not None 164*3117ece4Schristos assert is_if or self._elif.match(line) is not None 165*3117ece4Schristos depth = 0 166*3117ece4Schristos 167*3117ece4Schristos start_idx = idx 168*3117ece4Schristos 169*3117ece4Schristos idx += 1 170*3117ece4Schristos replace = prepend 171*3117ece4Schristos finished = False 172*3117ece4Schristos while idx < len(self._inlines): 173*3117ece4Schristos line = self._inlines[idx] 174*3117ece4Schristos # Nested if statement 175*3117ece4Schristos if self._if.match(line): 176*3117ece4Schristos depth += 1 177*3117ece4Schristos idx += 1 178*3117ece4Schristos continue 179*3117ece4Schristos # We're inside a nested statement 180*3117ece4Schristos if depth > 0: 181*3117ece4Schristos if self._endif.match(line): 182*3117ece4Schristos depth -= 1 183*3117ece4Schristos idx += 1 184*3117ece4Schristos continue 185*3117ece4Schristos 186*3117ece4Schristos # We're at the original depth 187*3117ece4Schristos 188*3117ece4Schristos # Looking only for an endif. 189*3117ece4Schristos # We've found a true statement, but haven't 190*3117ece4Schristos # completely elided the if block, so we just 191*3117ece4Schristos # remove the remainder. 192*3117ece4Schristos if state == REMOVE_REST: 193*3117ece4Schristos if self._endif.match(line): 194*3117ece4Schristos if is_if: 195*3117ece4Schristos # Remove the endif because we took the first if 196*3117ece4Schristos idx += 1 197*3117ece4Schristos finished = True 198*3117ece4Schristos break 199*3117ece4Schristos idx += 1 200*3117ece4Schristos continue 201*3117ece4Schristos 202*3117ece4Schristos if state == KEEP_ONE: 203*3117ece4Schristos m = self._elif.match(line) 204*3117ece4Schristos if self._endif.match(line): 205*3117ece4Schristos replace += self._inlines[start_idx + 1:idx] 206*3117ece4Schristos idx += 1 207*3117ece4Schristos finished = True 208*3117ece4Schristos break 209*3117ece4Schristos if self._elif.match(line) or self._else.match(line): 210*3117ece4Schristos replace += self._inlines[start_idx + 1:idx] 211*3117ece4Schristos state = REMOVE_REST 212*3117ece4Schristos idx += 1 213*3117ece4Schristos continue 214*3117ece4Schristos 215*3117ece4Schristos if state == REMOVE_ONE: 216*3117ece4Schristos m = self._elif.match(line) 217*3117ece4Schristos if m is not None: 218*3117ece4Schristos if is_if: 219*3117ece4Schristos idx += 1 220*3117ece4Schristos b = m.start('elif') 221*3117ece4Schristos e = m.end('elif') 222*3117ece4Schristos assert e - b == 2 223*3117ece4Schristos replace.append(line[:b] + line[e:]) 224*3117ece4Schristos finished = True 225*3117ece4Schristos break 226*3117ece4Schristos m = self._else.match(line) 227*3117ece4Schristos if m is not None: 228*3117ece4Schristos if is_if: 229*3117ece4Schristos idx += 1 230*3117ece4Schristos while self._endif.match(self._inlines[idx]) is None: 231*3117ece4Schristos replace.append(self._inlines[idx]) 232*3117ece4Schristos idx += 1 233*3117ece4Schristos idx += 1 234*3117ece4Schristos finished = True 235*3117ece4Schristos break 236*3117ece4Schristos if self._endif.match(line): 237*3117ece4Schristos if is_if: 238*3117ece4Schristos # Remove the endif because no other elifs 239*3117ece4Schristos idx += 1 240*3117ece4Schristos finished = True 241*3117ece4Schristos break 242*3117ece4Schristos idx += 1 243*3117ece4Schristos continue 244*3117ece4Schristos if not finished: 245*3117ece4Schristos raise RuntimeError("Unterminated if block!") 246*3117ece4Schristos 247*3117ece4Schristos replace = self._fixup_indentation(macro, replace) 248*3117ece4Schristos 249*3117ece4Schristos self._log(f"\tHardwiring {macro}") 250*3117ece4Schristos if start_idx > 0: 251*3117ece4Schristos self._log(f"\t\t {self._inlines[start_idx - 1][:-1]}") 252*3117ece4Schristos for x in range(start_idx, idx): 253*3117ece4Schristos self._log(f"\t\t- {self._inlines[x][:-1]}") 254*3117ece4Schristos for line in replace: 255*3117ece4Schristos self._log(f"\t\t+ {line[:-1]}") 256*3117ece4Schristos if idx < len(self._inlines): 257*3117ece4Schristos self._log(f"\t\t {self._inlines[idx][:-1]}") 258*3117ece4Schristos 259*3117ece4Schristos return idx, replace 260*3117ece4Schristos 261*3117ece4Schristos def _preprocess_once(self): 262*3117ece4Schristos outlines = [] 263*3117ece4Schristos idx = 0 264*3117ece4Schristos changed = False 265*3117ece4Schristos while idx < len(self._inlines): 266*3117ece4Schristos line = self._inlines[idx] 267*3117ece4Schristos sline = self._strip_comments(line) 268*3117ece4Schristos m = self._ifdef.fullmatch(sline) 269*3117ece4Schristos if_true = False 270*3117ece4Schristos if m is None: 271*3117ece4Schristos m = self._if_defined_value.fullmatch(sline) 272*3117ece4Schristos if m is None: 273*3117ece4Schristos m = self._if_defined.match(sline) 274*3117ece4Schristos if m is None: 275*3117ece4Schristos m = self._if_true.match(sline) 276*3117ece4Schristos if_true = (m is not None) 277*3117ece4Schristos if m is None: 278*3117ece4Schristos outlines.append(line) 279*3117ece4Schristos idx += 1 280*3117ece4Schristos continue 281*3117ece4Schristos 282*3117ece4Schristos groups = m.groupdict() 283*3117ece4Schristos macro = groups['macro'] 284*3117ece4Schristos op = groups.get('op') 285*3117ece4Schristos 286*3117ece4Schristos if not (macro in self._defs or macro in self._undefs): 287*3117ece4Schristos outlines.append(line) 288*3117ece4Schristos idx += 1 289*3117ece4Schristos continue 290*3117ece4Schristos 291*3117ece4Schristos defined = macro in self._defs 292*3117ece4Schristos 293*3117ece4Schristos # Needed variables set: 294*3117ece4Schristos # resolved: Is the statement fully resolved? 295*3117ece4Schristos # is_true: If resolved, is the statement true? 296*3117ece4Schristos ifdef = False 297*3117ece4Schristos if if_true: 298*3117ece4Schristos if not defined: 299*3117ece4Schristos outlines.append(line) 300*3117ece4Schristos idx += 1 301*3117ece4Schristos continue 302*3117ece4Schristos 303*3117ece4Schristos defined_value = self._defs[macro] 304*3117ece4Schristos is_int = True 305*3117ece4Schristos try: 306*3117ece4Schristos defined_value = int(defined_value) 307*3117ece4Schristos except TypeError: 308*3117ece4Schristos is_int = False 309*3117ece4Schristos except ValueError: 310*3117ece4Schristos is_int = False 311*3117ece4Schristos 312*3117ece4Schristos resolved = is_int 313*3117ece4Schristos is_true = (defined_value != 0) 314*3117ece4Schristos 315*3117ece4Schristos if resolved and op is not None: 316*3117ece4Schristos if op == '&&': 317*3117ece4Schristos resolved = not is_true 318*3117ece4Schristos else: 319*3117ece4Schristos assert op == '||' 320*3117ece4Schristos resolved = is_true 321*3117ece4Schristos 322*3117ece4Schristos else: 323*3117ece4Schristos ifdef = groups.get('not') is None 324*3117ece4Schristos elseif = groups.get('elif') is not None 325*3117ece4Schristos 326*3117ece4Schristos macro2 = groups.get('macro2') 327*3117ece4Schristos cmp = groups.get('cmp') 328*3117ece4Schristos value = groups.get('value') 329*3117ece4Schristos openp = groups.get('openp') 330*3117ece4Schristos closep = groups.get('closep') 331*3117ece4Schristos 332*3117ece4Schristos is_true = (ifdef == defined) 333*3117ece4Schristos resolved = True 334*3117ece4Schristos if op is not None: 335*3117ece4Schristos if op == '&&': 336*3117ece4Schristos resolved = not is_true 337*3117ece4Schristos else: 338*3117ece4Schristos assert op == '||' 339*3117ece4Schristos resolved = is_true 340*3117ece4Schristos 341*3117ece4Schristos if macro2 is not None and not resolved: 342*3117ece4Schristos assert ifdef and defined and op == '&&' and cmp is not None 343*3117ece4Schristos # If the statement is true, but we have a single value check, then 344*3117ece4Schristos # check the value. 345*3117ece4Schristos defined_value = self._defs[macro] 346*3117ece4Schristos are_ints = True 347*3117ece4Schristos try: 348*3117ece4Schristos defined_value = int(defined_value) 349*3117ece4Schristos value = int(value) 350*3117ece4Schristos except TypeError: 351*3117ece4Schristos are_ints = False 352*3117ece4Schristos except ValueError: 353*3117ece4Schristos are_ints = False 354*3117ece4Schristos if ( 355*3117ece4Schristos macro == macro2 and 356*3117ece4Schristos ((openp is None) == (closep is None)) and 357*3117ece4Schristos are_ints 358*3117ece4Schristos ): 359*3117ece4Schristos resolved = True 360*3117ece4Schristos if cmp == '<': 361*3117ece4Schristos is_true = defined_value < value 362*3117ece4Schristos elif cmp == '<=': 363*3117ece4Schristos is_true = defined_value <= value 364*3117ece4Schristos elif cmp == '==': 365*3117ece4Schristos is_true = defined_value == value 366*3117ece4Schristos elif cmp == '!=': 367*3117ece4Schristos is_true = defined_value != value 368*3117ece4Schristos elif cmp == '>=': 369*3117ece4Schristos is_true = defined_value >= value 370*3117ece4Schristos elif cmp == '>': 371*3117ece4Schristos is_true = defined_value > value 372*3117ece4Schristos else: 373*3117ece4Schristos resolved = False 374*3117ece4Schristos 375*3117ece4Schristos if op is not None and not resolved: 376*3117ece4Schristos # Remove the first op in the line + spaces 377*3117ece4Schristos if op == '&&': 378*3117ece4Schristos opre = op 379*3117ece4Schristos else: 380*3117ece4Schristos assert op == '||' 381*3117ece4Schristos opre = r'\|\|' 382*3117ece4Schristos needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)") 383*3117ece4Schristos match = needle.match(line) 384*3117ece4Schristos assert match is not None 385*3117ece4Schristos newline = line[:match.end('if')] + line[match.end('op'):] 386*3117ece4Schristos 387*3117ece4Schristos self._log(f"\tHardwiring partially resolved {macro}") 388*3117ece4Schristos self._log(f"\t\t- {line[:-1]}") 389*3117ece4Schristos self._log(f"\t\t+ {newline[:-1]}") 390*3117ece4Schristos 391*3117ece4Schristos outlines.append(newline) 392*3117ece4Schristos idx += 1 393*3117ece4Schristos continue 394*3117ece4Schristos 395*3117ece4Schristos # Skip any statements we cannot fully compute 396*3117ece4Schristos if not resolved: 397*3117ece4Schristos outlines.append(line) 398*3117ece4Schristos idx += 1 399*3117ece4Schristos continue 400*3117ece4Schristos 401*3117ece4Schristos prepend = [] 402*3117ece4Schristos if macro in self._replaces: 403*3117ece4Schristos assert not ifdef 404*3117ece4Schristos assert op is None 405*3117ece4Schristos value = self._replaces.pop(macro) 406*3117ece4Schristos prepend = [f"#define {macro} {value}\n"] 407*3117ece4Schristos 408*3117ece4Schristos idx, replace = self._handle_if_block(macro, idx, is_true, prepend) 409*3117ece4Schristos outlines += replace 410*3117ece4Schristos changed = True 411*3117ece4Schristos 412*3117ece4Schristos return changed, outlines 413*3117ece4Schristos 414*3117ece4Schristos def preprocess(self, filename): 415*3117ece4Schristos with open(filename, 'r') as f: 416*3117ece4Schristos self._inlines = f.readlines() 417*3117ece4Schristos changed = True 418*3117ece4Schristos iters = 0 419*3117ece4Schristos while changed: 420*3117ece4Schristos iters += 1 421*3117ece4Schristos changed, outlines = self._preprocess_once() 422*3117ece4Schristos self._inlines = outlines 423*3117ece4Schristos 424*3117ece4Schristos with open(filename, 'w') as f: 425*3117ece4Schristos f.write(''.join(self._inlines)) 426*3117ece4Schristos 427*3117ece4Schristos 428*3117ece4Schristosclass Freestanding(object): 429*3117ece4Schristos def __init__( 430*3117ece4Schristos self, zstd_deps: str, mem: str, source_lib: str, output_lib: str, 431*3117ece4Schristos external_xxhash: bool, xxh64_state: Optional[str], 432*3117ece4Schristos xxh64_prefix: Optional[str], rewritten_includes: [(str, str)], 433*3117ece4Schristos defs: [(str, Optional[str])], replaces: [(str, str)], 434*3117ece4Schristos undefs: [str], excludes: [str], seds: [str], spdx: bool, 435*3117ece4Schristos ): 436*3117ece4Schristos self._zstd_deps = zstd_deps 437*3117ece4Schristos self._mem = mem 438*3117ece4Schristos self._src_lib = source_lib 439*3117ece4Schristos self._dst_lib = output_lib 440*3117ece4Schristos self._external_xxhash = external_xxhash 441*3117ece4Schristos self._xxh64_state = xxh64_state 442*3117ece4Schristos self._xxh64_prefix = xxh64_prefix 443*3117ece4Schristos self._rewritten_includes = rewritten_includes 444*3117ece4Schristos self._defs = defs 445*3117ece4Schristos self._replaces = replaces 446*3117ece4Schristos self._undefs = undefs 447*3117ece4Schristos self._excludes = excludes 448*3117ece4Schristos self._seds = seds 449*3117ece4Schristos self._spdx = spdx 450*3117ece4Schristos 451*3117ece4Schristos def _dst_lib_file_paths(self): 452*3117ece4Schristos """ 453*3117ece4Schristos Yields all the file paths in the dst_lib. 454*3117ece4Schristos """ 455*3117ece4Schristos for root, dirname, filenames in os.walk(self._dst_lib): 456*3117ece4Schristos for filename in filenames: 457*3117ece4Schristos filepath = os.path.join(root, filename) 458*3117ece4Schristos yield filepath 459*3117ece4Schristos 460*3117ece4Schristos def _log(self, *args, **kwargs): 461*3117ece4Schristos print(*args, **kwargs) 462*3117ece4Schristos 463*3117ece4Schristos def _copy_file(self, lib_path): 464*3117ece4Schristos suffixes = [".c", ".h", ".S"] 465*3117ece4Schristos if not any((lib_path.endswith(suffix) for suffix in suffixes)): 466*3117ece4Schristos return 467*3117ece4Schristos if lib_path in SKIPPED_FILES: 468*3117ece4Schristos self._log(f"\tSkipping file: {lib_path}") 469*3117ece4Schristos return 470*3117ece4Schristos if self._external_xxhash and lib_path in XXHASH_FILES: 471*3117ece4Schristos self._log(f"\tSkipping xxhash file: {lib_path}") 472*3117ece4Schristos return 473*3117ece4Schristos 474*3117ece4Schristos src_path = os.path.join(self._src_lib, lib_path) 475*3117ece4Schristos dst_path = os.path.join(self._dst_lib, lib_path) 476*3117ece4Schristos self._log(f"\tCopying: {src_path} -> {dst_path}") 477*3117ece4Schristos shutil.copyfile(src_path, dst_path) 478*3117ece4Schristos 479*3117ece4Schristos def _copy_source_lib(self): 480*3117ece4Schristos self._log("Copying source library into output library") 481*3117ece4Schristos 482*3117ece4Schristos assert os.path.exists(self._src_lib) 483*3117ece4Schristos os.makedirs(self._dst_lib, exist_ok=True) 484*3117ece4Schristos self._copy_file("zstd.h") 485*3117ece4Schristos self._copy_file("zstd_errors.h") 486*3117ece4Schristos for subdir in INCLUDED_SUBDIRS: 487*3117ece4Schristos src_dir = os.path.join(self._src_lib, subdir) 488*3117ece4Schristos dst_dir = os.path.join(self._dst_lib, subdir) 489*3117ece4Schristos 490*3117ece4Schristos assert os.path.exists(src_dir) 491*3117ece4Schristos os.makedirs(dst_dir, exist_ok=True) 492*3117ece4Schristos 493*3117ece4Schristos for filename in os.listdir(src_dir): 494*3117ece4Schristos lib_path = os.path.join(subdir, filename) 495*3117ece4Schristos self._copy_file(lib_path) 496*3117ece4Schristos 497*3117ece4Schristos def _copy_zstd_deps(self): 498*3117ece4Schristos dst_zstd_deps = os.path.join(self._dst_lib, "common", "zstd_deps.h") 499*3117ece4Schristos self._log(f"Copying zstd_deps: {self._zstd_deps} -> {dst_zstd_deps}") 500*3117ece4Schristos shutil.copyfile(self._zstd_deps, dst_zstd_deps) 501*3117ece4Schristos 502*3117ece4Schristos def _copy_mem(self): 503*3117ece4Schristos dst_mem = os.path.join(self._dst_lib, "common", "mem.h") 504*3117ece4Schristos self._log(f"Copying mem: {self._mem} -> {dst_mem}") 505*3117ece4Schristos shutil.copyfile(self._mem, dst_mem) 506*3117ece4Schristos 507*3117ece4Schristos def _hardwire_preprocessor(self, name: str, value: Optional[str] = None, undef=False): 508*3117ece4Schristos """ 509*3117ece4Schristos If value=None then hardwire that it is defined, but not what the value is. 510*3117ece4Schristos If undef=True then value must be None. 511*3117ece4Schristos If value='' then the macro is defined to '' exactly. 512*3117ece4Schristos """ 513*3117ece4Schristos assert not (undef and value is not None) 514*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 515*3117ece4Schristos file = FileLines(filepath) 516*3117ece4Schristos 517*3117ece4Schristos def _hardwire_defines(self): 518*3117ece4Schristos self._log("Hardwiring macros") 519*3117ece4Schristos partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs) 520*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 521*3117ece4Schristos partial_preprocessor.preprocess(filepath) 522*3117ece4Schristos 523*3117ece4Schristos def _remove_excludes(self): 524*3117ece4Schristos self._log("Removing excluded sections") 525*3117ece4Schristos for exclude in self._excludes: 526*3117ece4Schristos self._log(f"\tRemoving excluded sections for: {exclude}") 527*3117ece4Schristos begin_re = re.compile(f"BEGIN {exclude}") 528*3117ece4Schristos end_re = re.compile(f"END {exclude}") 529*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 530*3117ece4Schristos file = FileLines(filepath) 531*3117ece4Schristos outlines = [] 532*3117ece4Schristos skipped = [] 533*3117ece4Schristos emit = True 534*3117ece4Schristos for line in file.lines: 535*3117ece4Schristos if emit and begin_re.search(line) is not None: 536*3117ece4Schristos assert end_re.search(line) is None 537*3117ece4Schristos emit = False 538*3117ece4Schristos if emit: 539*3117ece4Schristos outlines.append(line) 540*3117ece4Schristos else: 541*3117ece4Schristos skipped.append(line) 542*3117ece4Schristos if end_re.search(line) is not None: 543*3117ece4Schristos assert begin_re.search(line) is None 544*3117ece4Schristos self._log(f"\t\tRemoving excluded section: {exclude}") 545*3117ece4Schristos for s in skipped: 546*3117ece4Schristos self._log(f"\t\t\t- {s}") 547*3117ece4Schristos emit = True 548*3117ece4Schristos skipped = [] 549*3117ece4Schristos if not emit: 550*3117ece4Schristos raise RuntimeError("Excluded section unfinished!") 551*3117ece4Schristos file.lines = outlines 552*3117ece4Schristos file.write() 553*3117ece4Schristos 554*3117ece4Schristos def _rewrite_include(self, original, rewritten): 555*3117ece4Schristos self._log(f"\tRewriting include: {original} -> {rewritten}") 556*3117ece4Schristos regex = re.compile(f"\\s*#\\s*include\\s*(?P<include>{original})") 557*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 558*3117ece4Schristos file = FileLines(filepath) 559*3117ece4Schristos for i, line in enumerate(file.lines): 560*3117ece4Schristos match = regex.match(line) 561*3117ece4Schristos if match is None: 562*3117ece4Schristos continue 563*3117ece4Schristos s = match.start('include') 564*3117ece4Schristos e = match.end('include') 565*3117ece4Schristos file.lines[i] = line[:s] + rewritten + line[e:] 566*3117ece4Schristos file.write() 567*3117ece4Schristos 568*3117ece4Schristos def _rewrite_includes(self): 569*3117ece4Schristos self._log("Rewriting includes") 570*3117ece4Schristos for original, rewritten in self._rewritten_includes: 571*3117ece4Schristos self._rewrite_include(original, rewritten) 572*3117ece4Schristos 573*3117ece4Schristos def _replace_xxh64_prefix(self): 574*3117ece4Schristos if self._xxh64_prefix is None: 575*3117ece4Schristos return 576*3117ece4Schristos self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}") 577*3117ece4Schristos replacements = [] 578*3117ece4Schristos if self._xxh64_state is not None: 579*3117ece4Schristos replacements.append( 580*3117ece4Schristos (re.compile(r"([^\w]|^)(?P<orig>XXH64_state_t)([^\w]|$)"), self._xxh64_state) 581*3117ece4Schristos ) 582*3117ece4Schristos if self._xxh64_prefix is not None: 583*3117ece4Schristos replacements.append( 584*3117ece4Schristos (re.compile(r"([^\w]|^)(?P<orig>XXH64)[\(_]"), self._xxh64_prefix) 585*3117ece4Schristos ) 586*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 587*3117ece4Schristos file = FileLines(filepath) 588*3117ece4Schristos for i, line in enumerate(file.lines): 589*3117ece4Schristos modified = False 590*3117ece4Schristos for regex, replacement in replacements: 591*3117ece4Schristos match = regex.search(line) 592*3117ece4Schristos while match is not None: 593*3117ece4Schristos modified = True 594*3117ece4Schristos b = match.start('orig') 595*3117ece4Schristos e = match.end('orig') 596*3117ece4Schristos line = line[:b] + replacement + line[e:] 597*3117ece4Schristos match = regex.search(line) 598*3117ece4Schristos if modified: 599*3117ece4Schristos self._log(f"\t- {file.lines[i][:-1]}") 600*3117ece4Schristos self._log(f"\t+ {line[:-1]}") 601*3117ece4Schristos file.lines[i] = line 602*3117ece4Schristos file.write() 603*3117ece4Schristos 604*3117ece4Schristos def _parse_sed(self, sed): 605*3117ece4Schristos assert sed[0] == 's' 606*3117ece4Schristos delim = sed[1] 607*3117ece4Schristos match = re.fullmatch(f's{delim}(.+){delim}(.*){delim}(.*)', sed) 608*3117ece4Schristos assert match is not None 609*3117ece4Schristos regex = re.compile(match.group(1)) 610*3117ece4Schristos format_str = match.group(2) 611*3117ece4Schristos is_global = match.group(3) == 'g' 612*3117ece4Schristos return regex, format_str, is_global 613*3117ece4Schristos 614*3117ece4Schristos def _process_sed(self, sed): 615*3117ece4Schristos self._log(f"Processing sed: {sed}") 616*3117ece4Schristos regex, format_str, is_global = self._parse_sed(sed) 617*3117ece4Schristos 618*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 619*3117ece4Schristos file = FileLines(filepath) 620*3117ece4Schristos for i, line in enumerate(file.lines): 621*3117ece4Schristos modified = False 622*3117ece4Schristos while True: 623*3117ece4Schristos match = regex.search(line) 624*3117ece4Schristos if match is None: 625*3117ece4Schristos break 626*3117ece4Schristos replacement = format_str.format(match.groups(''), match.groupdict('')) 627*3117ece4Schristos b = match.start() 628*3117ece4Schristos e = match.end() 629*3117ece4Schristos line = line[:b] + replacement + line[e:] 630*3117ece4Schristos modified = True 631*3117ece4Schristos if not is_global: 632*3117ece4Schristos break 633*3117ece4Schristos if modified: 634*3117ece4Schristos self._log(f"\t- {file.lines[i][:-1]}") 635*3117ece4Schristos self._log(f"\t+ {line[:-1]}") 636*3117ece4Schristos file.lines[i] = line 637*3117ece4Schristos file.write() 638*3117ece4Schristos 639*3117ece4Schristos def _process_seds(self): 640*3117ece4Schristos self._log("Processing seds") 641*3117ece4Schristos for sed in self._seds: 642*3117ece4Schristos self._process_sed(sed) 643*3117ece4Schristos 644*3117ece4Schristos def _process_spdx(self): 645*3117ece4Schristos if not self._spdx: 646*3117ece4Schristos return 647*3117ece4Schristos self._log("Processing spdx") 648*3117ece4Schristos SPDX_C = "// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause\n" 649*3117ece4Schristos SPDX_H_S = "/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */\n" 650*3117ece4Schristos for filepath in self._dst_lib_file_paths(): 651*3117ece4Schristos file = FileLines(filepath) 652*3117ece4Schristos if file.lines[0] == SPDX_C or file.lines[0] == SPDX_H_S: 653*3117ece4Schristos continue 654*3117ece4Schristos for line in file.lines: 655*3117ece4Schristos if "SPDX-License-Identifier" in line: 656*3117ece4Schristos raise RuntimeError(f"Unexpected SPDX license identifier: {file.filename} {repr(line)}") 657*3117ece4Schristos if file.filename.endswith(".c"): 658*3117ece4Schristos file.lines.insert(0, SPDX_C) 659*3117ece4Schristos elif file.filename.endswith(".h") or file.filename.endswith(".S"): 660*3117ece4Schristos file.lines.insert(0, SPDX_H_S) 661*3117ece4Schristos else: 662*3117ece4Schristos raise RuntimeError(f"Unexpected file extension: {file.filename}") 663*3117ece4Schristos file.write() 664*3117ece4Schristos 665*3117ece4Schristos 666*3117ece4Schristos 667*3117ece4Schristos def go(self): 668*3117ece4Schristos self._copy_source_lib() 669*3117ece4Schristos self._copy_zstd_deps() 670*3117ece4Schristos self._copy_mem() 671*3117ece4Schristos self._hardwire_defines() 672*3117ece4Schristos self._remove_excludes() 673*3117ece4Schristos self._rewrite_includes() 674*3117ece4Schristos self._replace_xxh64_prefix() 675*3117ece4Schristos self._process_seds() 676*3117ece4Schristos self._process_spdx() 677*3117ece4Schristos 678*3117ece4Schristos 679*3117ece4Schristosdef parse_optional_pair(defines: [str]) -> [(str, Optional[str])]: 680*3117ece4Schristos output = [] 681*3117ece4Schristos for define in defines: 682*3117ece4Schristos parsed = define.split('=') 683*3117ece4Schristos if len(parsed) == 1: 684*3117ece4Schristos output.append((parsed[0], None)) 685*3117ece4Schristos elif len(parsed) == 2: 686*3117ece4Schristos output.append((parsed[0], parsed[1])) 687*3117ece4Schristos else: 688*3117ece4Schristos raise RuntimeError(f"Bad define: {define}") 689*3117ece4Schristos return output 690*3117ece4Schristos 691*3117ece4Schristos 692*3117ece4Schristosdef parse_pair(rewritten_includes: [str]) -> [(str, str)]: 693*3117ece4Schristos output = [] 694*3117ece4Schristos for rewritten_include in rewritten_includes: 695*3117ece4Schristos parsed = rewritten_include.split('=') 696*3117ece4Schristos if len(parsed) == 2: 697*3117ece4Schristos output.append((parsed[0], parsed[1])) 698*3117ece4Schristos else: 699*3117ece4Schristos raise RuntimeError(f"Bad rewritten include: {rewritten_include}") 700*3117ece4Schristos return output 701*3117ece4Schristos 702*3117ece4Schristos 703*3117ece4Schristos 704*3117ece4Schristosdef main(name, args): 705*3117ece4Schristos parser = argparse.ArgumentParser(prog=name) 706*3117ece4Schristos parser.add_argument("--zstd-deps", default="zstd_deps.h", help="Zstd dependencies file") 707*3117ece4Schristos parser.add_argument("--mem", default="mem.h", help="Memory module") 708*3117ece4Schristos parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library") 709*3117ece4Schristos parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library") 710*3117ece4Schristos parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.") 711*3117ece4Schristos parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'") 712*3117ece4Schristos parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64") 713*3117ece4Schristos parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')") 714*3117ece4Schristos parser.add_argument("--sed", default=[], dest="seds", action="append", help="Apply a sed replacement. Format: `s/REGEX/FORMAT/[g]`. REGEX is a Python regex. FORMAT is a Python format string formatted by the regex dict.") 715*3117ece4Schristos parser.add_argument("--spdx", action="store_true", help="Add SPDX License Identifiers") 716*3117ece4Schristos parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)") 717*3117ece4Schristos parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed multiple times)") 718*3117ece4Schristos parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition") 719*3117ece4Schristos parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'") 720*3117ece4Schristos args = parser.parse_args(args) 721*3117ece4Schristos 722*3117ece4Schristos # Always remove threading 723*3117ece4Schristos if "ZSTD_MULTITHREAD" not in args.undefs: 724*3117ece4Schristos args.undefs.append("ZSTD_MULTITHREAD") 725*3117ece4Schristos 726*3117ece4Schristos args.defs = parse_optional_pair(args.defs) 727*3117ece4Schristos for name, _ in args.defs: 728*3117ece4Schristos if name in args.undefs: 729*3117ece4Schristos raise RuntimeError(f"{name} is both defined and undefined!") 730*3117ece4Schristos 731*3117ece4Schristos # Always set tracing to 0 732*3117ece4Schristos if "ZSTD_NO_TRACE" not in (arg[0] for arg in args.defs): 733*3117ece4Schristos args.defs.append(("ZSTD_NO_TRACE", None)) 734*3117ece4Schristos args.defs.append(("ZSTD_TRACE", "0")) 735*3117ece4Schristos 736*3117ece4Schristos args.replaces = parse_pair(args.replaces) 737*3117ece4Schristos for name, _ in args.replaces: 738*3117ece4Schristos if name in args.undefs or name in args.defs: 739*3117ece4Schristos raise RuntimeError(f"{name} is both replaced and (un)defined!") 740*3117ece4Schristos 741*3117ece4Schristos args.rewritten_includes = parse_pair(args.rewritten_includes) 742*3117ece4Schristos 743*3117ece4Schristos external_xxhash = False 744*3117ece4Schristos if args.xxhash is not None: 745*3117ece4Schristos external_xxhash = True 746*3117ece4Schristos args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash)) 747*3117ece4Schristos 748*3117ece4Schristos if args.xxh64_prefix is not None: 749*3117ece4Schristos if not external_xxhash: 750*3117ece4Schristos raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided") 751*3117ece4Schristos 752*3117ece4Schristos if args.xxh64_state is not None: 753*3117ece4Schristos if not external_xxhash: 754*3117ece4Schristos raise RuntimeError("--xxh64-state may only be used with --xxhash provided") 755*3117ece4Schristos 756*3117ece4Schristos Freestanding( 757*3117ece4Schristos args.zstd_deps, 758*3117ece4Schristos args.mem, 759*3117ece4Schristos args.source_lib, 760*3117ece4Schristos args.output_lib, 761*3117ece4Schristos external_xxhash, 762*3117ece4Schristos args.xxh64_state, 763*3117ece4Schristos args.xxh64_prefix, 764*3117ece4Schristos args.rewritten_includes, 765*3117ece4Schristos args.defs, 766*3117ece4Schristos args.replaces, 767*3117ece4Schristos args.undefs, 768*3117ece4Schristos args.excludes, 769*3117ece4Schristos args.seds, 770*3117ece4Schristos args.spdx, 771*3117ece4Schristos ).go() 772*3117ece4Schristos 773*3117ece4Schristosif __name__ == "__main__": 774*3117ece4Schristos main(sys.argv[0], sys.argv[1:]) 775