xref: /llvm-project/clang/utils/ABITest/Enumeration.py (revision dd3c26a045c081620375a878159f536758baba6e)
1"""Utilities for enumeration of finite and countably infinite sets.
2"""
3from __future__ import absolute_import, division, print_function
4
5###
6# Countable iteration
7
8# Simplifies some calculations
class Aleph0(int):
    """The first infinite cardinal, used as an "unbounded" size or bound.

    Aleph0() always returns the same singleton object, so callers test for
    it with ``x is aleph0``.  Arithmetic absorbs finite operands
    (aleph0 + n is aleph0, aleph0 * n is aleph0 for n != 0), subtraction
    raises ValueError, and aleph0 compares greater than every other value
    while being equal only to itself.

    Python 3 ignores ``__cmp__``, and because this class subclasses int it
    would otherwise fall back to int's comparisons on the underlying value
    0 (so ``5 < aleph0`` would be False and ``aleph0 <= 0`` True).  The
    rich comparison methods below restore the intended ordering;
    ``__cmp__`` is kept for Python 2.
    """

    _singleton = None

    def __new__(cls):
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return "<aleph0>"

    def __str__(self):
        return "inf"

    # Ordering: aleph0 equals only itself and exceeds everything else.
    # Since Aleph0 subclasses int, Python tries these (reflected) methods
    # first for mixed comparisons such as ``5 < aleph0``.
    def __eq__(self, b):
        return b is self

    def __ne__(self, b):
        # int defines __ne__, so it must be overridden explicitly too.
        return b is not self

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return b is self

    def __gt__(self, b):
        return b is not self

    def __ge__(self, b):
        return True

    def __hash__(self):
        # Defining __eq__ would otherwise make instances unhashable.
        return object.__hash__(self)

    def __cmp__(self, b):  # Python 2 only; Python 3 uses the methods above.
        return 0 if b is self else 1

    def __bool__(self):
        # The underlying int value is 0, but an infinite quantity is truthy.
        return True

    __nonzero__ = __bool__  # Python 2 spelling.

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")

    __rsub__ = __sub__

    def __add__(self, b):
        return self

    __radd__ = __add__

    def __mul__(self, b):
        if b == 0:
            return b
        return self

    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self

    # The reflected forms reuse the forward implementation (so a nonzero
    # finite value divided by aleph0 yields aleph0, not 0); nothing in this
    # file divides by aleph0.
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


aleph0 = Aleph0()
61
62
def base(line):
    """Return the line-th triangular number, line * (line + 1) / 2.

    This counts the pairs on diagonals 0 .. line-1 (diagonal k holds k + 1
    pairs), i.e. the index at which diagonal `line` begins in the diagonal
    pair enumeration.
    """
    doubled = line * line + line
    return doubled // 2
65
66
def pairToN(pair):
    """Return the index of `pair` in the diagonal pair enumeration.

    This is the inverse of getNthPair: the pair (x, y) lies on diagonal
    x + y at offset y along it.
    """
    x, y = pair
    diagonal = x + y
    # The triangular number counts every pair on the earlier diagonals.
    return diagonal * (diagonal + 1) // 2 + y
71
72
def getNthPairInfo(N):
    """Locate the N-th pair of the diagonal enumeration.

    Returns (line, index) such that N == base(line) + index: the pair lies
    on diagonal `line` (where x + y == line) and is its `index`-th element.
    """

    def tri(k):
        # Triangular number; identical to base(k).
        return k * (k + 1) // 2

    # Handle N == 0 up front so the search below can start from line 1.
    if N == 0:
        return (0, 0)

    # Galloping search: keep doubling `hi` until tri(hi) exceeds N.
    lo, hi = 1, 2
    while tri(hi) <= N:
        lo, hi = hi, hi * 2

    # Bisection, maintaining tri(lo) <= N < tri(hi).
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if tri(mid) > N:
            hi = mid
        else:
            lo = mid

    return lo, N - tri(lo)
98
99
def getNthPair(N):
    """Return the N-th pair (x, y) of naturals in diagonal order.

    Each diagonal x + y == line is walked from (line, 0) to (0, line), so
    getNthPair(1) == (1, 0) and getNthPair(2) == (0, 1).  This is the
    inverse of pairToN.
    """
    diagonal, offset = getNthPairInfo(N)
    x = diagonal - offset
    return (x, offset)
103
104
def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    W and H may each be a positive int or aleph0.  With useDivmod the
    coordinate with the smaller bound varies fastest (plain divmod);
    otherwise a diagonal is slid across the W x H rectangle.

    Raises ValueError if a bound is not positive or if N is not in
    [0, W * H)."""

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W * H:
        # Negative N would otherwise fall through to getNthPair and yield
        # pairs outside the rectangle (e.g. x == W).
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W <= H (so W is finite from here on).
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # If in lower left triangle (diagonals 0 .. W-1), just return as
        # usual.
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Otherwise if in upper right, mirror the index from the far
        # corner.
        if H is not aleph0:
            M = W * H - N - 1
            if M < cornerSize:
                x, y = getNthPair(M)
                return (W - 1 - x, H - 1 - y)

        # Otherwise, compute line and index from number of times we
        # wrap: every full diagonal in the middle band holds W pairs.
        N = N - cornerSize
        index, offset = N % W, N // W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W - 1 - index, 1 + offset + index)
149
150
def getNthPairBoundedChecked(
    N, W=aleph0, H=aleph0, useDivmod=False, GNP=getNthPairBounded
):
    """Debugging wrapper around getNthPairBounded.

    Delegates to GNP (by default the implementation bound when this
    function was defined) and asserts that the returned pair lies inside
    the W x H rectangle before returning it.
    """
    x, y = GNP(N, W, H, useDivmod)
    assert 0 <= x < W
    assert 0 <= y < H
    return x, y
157
158
def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every element.

    By default the tuple is split into a left part of W // 2 elements and a
    right part holding the rest, and N is split into indices for the two
    parts with getNthPairBounded.  With useLeftToRight, elements are
    instead peeled off one at a time from the left."""

    if useLeftToRight:
        elts = [None] * W
        for i in range(W):
            # Pair this element (< H) with the index of the remaining
            # W - 1 - i elements (< H ** (W - 1 - i)).  For the last element
            # the second bound is H ** 0 == 1, so the final remainder becomes
            # that element; dropping it would make distinct N collide and
            # leave some tuples unreachable.
            elts[i], N = getNthPairBounded(N, H, H ** (W - 1 - i))
        return tuple(elts)
    else:
        if W == 0:
            return ()
        elif W == 1:
            return (N,)
        elif W == 2:
            return getNthPairBounded(N, H, H)
        else:
            LW, RW = W // 2, W - (W // 2)
            L, R = getNthPairBounded(N, H**LW, H**RW)
            return getNthNTuple(
                L, LW, H=H, useLeftToRight=useLeftToRight
            ) + getNthNTuple(R, RW, H=H, useLeftToRight=useLeftToRight)
182
183
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Debugging wrapper around getNthNTuple.

    Delegates to GNT and asserts that the result has exactly W elements,
    each strictly below H.
    """
    result = GNT(N, W, H, useLeftToRight)
    assert len(result) == W
    assert all(elt < H for elt in result)
    return result
190
191
def getNthTuple(
    N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False
):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple x where len(x) <= maxSize and every y in x
    satisfies 0 <= y < maxElement.

    maxSize is an inclusive bound on the length: index 0 is the empty tuple
    and non-empty tuples have lengths 1 .. maxSize.  maxElement is an
    exclusive bound on the elements.  useDivmod only affects the
    enumeration when maxElement is aleph0.  A finite maxElement requires a
    finite maxSize (NotImplementedError otherwise)."""

    # There is exactly one empty tuple, so it takes index 0.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError("Max element size without max size unhandled")
        # bounds[k] is the number of tuples of length k + 1.
        bounds = [maxElement**i for i in range(1, maxSize + 1)]
        S, M = getNthPairVariableBounds(N, bounds)
    else:
        # S selects the length (S + 1), M the index among tuples of it.
        S, M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    return getNthNTuple(M, S + 1, maxElement, useLeftToRight=useLeftToRight)
212
213
def getNthTupleChecked(
    N,
    maxSize=aleph0,
    maxElement=aleph0,
    useDivmod=False,
    useLeftToRight=False,
    GNT=getNthTuple,
):
    """Debugging wrapper around getNthTuple.

    Delegates to GNT and asserts the bounds of the result.  maxSize is an
    inclusive bound on the tuple length, while maxElement is an exclusive
    bound on each element.
    """
    result = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(result) <= maxSize
    assert all(elt < maxElement for elt in result)
    return result
228
229
def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Columns are ordered by increasing bound, and the plane is cut into
    horizontal bands between consecutive bound values.  Each band is a
    rectangle (columns still active) x (band height) and is enumerated
    with getNthPairBounded.

    Raises ValueError for an empty list or a negative bound, or if N is
    not in [0, sum(bounds))."""

    if not bounds or any(b < 0 for b in bounds):
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Column indices sorted by bound, so that the columns still reaching a
    # given height always form the suffix active[i:].
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        N -= levelSize
        prevLevel = level

    # Should be unreachable: the band sizes add up to sum(bounds), and
    # N < sum(bounds) was checked above.
    raise RuntimeError("Unexpected loop completion")
262
263
def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Debugging wrapper around getNthPairVariableBounds.

    Delegates to GNVP and asserts that x indexes `bounds` and that y lies
    below the bound of the selected column.
    """
    x, y = GNVP(N, bounds)
    assert 0 <= x < len(bounds)
    assert 0 <= y < bounds[x]
    return (x, y)
268
269
270###
271
272
def testPairs():
    """Print the first pairs of a 3 x 6 rectangle under both orderings.

    Grid "a" shows the diagonal ordering and grid "b" the divmod ordering;
    rows are printed top-down so that y increases upward.
    """
    W = 3
    H = 6

    def blankGrid():
        return [["  "] * 10 for _ in range(10)]

    def dump(title, grid):
        print(title)
        for row in reversed(grid):
            if "".join(row).strip():
                print("  ".join(row))

    a, b = blankGrid(), blankGrid()
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (x2, y2))
        label = "%2d" % i
        a[y][x] = label
        b[y2][x2] = label

    dump("-- a --", a)
    dump("-- b --", b)
293
294
def testPairsVB():
    """Print the first pairs produced by getNthPairVariableBounds.

    The sample bounds mix finite values and aleph0.  Each (x, y) is printed
    as it is produced, then shown on a grid with y increasing upward.
    """
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    grid = [["  " for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        print(i, (x, y))
        grid[y][x] = "%2d" % i

    print("-- a --")
    for row in reversed(grid):
        if "".join(row).strip():
            print("  ".join(row))
308
309
310###
311
# Toggle to use checked versions of enumeration routines.
# Change False to True to rebind the public enumeration functions to their
# *Checked wrappers, which assert that each result respects its bounds.
# This must stay after all the definitions above: each wrapper captured the
# unchecked implementation as its GNP/GNT/GNVP default when it was defined,
# so the rebinding does not make a wrapper call itself.  Calls that look
# these names up globally (e.g. the recursion inside getNthPairBounded and
# the calls made by getNthTuple) go through the checked versions once this
# is enabled.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked
318
if __name__ == "__main__":
    # Run the demo printers in order: fixed bounds first, then variable.
    for demo in (testPairs, testPairsVB):
        demo()
323