"""Utilities for enumeration of finite and countably infinite sets.
"""
from __future__ import absolute_import, division, print_function

###
# Countable iteration


# Simplifies some calculations
class Aleph0(int):
    """Singleton that stands in for countable infinity in bound arithmetic.

    The instance compares greater than every other value, absorbs addition
    and non-zero multiplication, and refuses subtraction.  It subclasses
    ``int`` so that Python gives its reflected operators priority when the
    other operand is a plain ``int`` (e.g. ``3 * aleph0`` or ``5 < aleph0``).
    """

    _singleton = None

    def __new__(cls):
        # Only ever build one instance so callers can test with ``is``.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return "<aleph0>"

    def __str__(self):
        return "inf"

    # Python 2 ordering hook; Python 3 ignores it and uses the rich
    # comparison methods below.
    def __cmp__(self, b):
        return 1

    # Rich comparisons: without these, Python 3 falls back to comparing the
    # underlying int value (0), which makes every bound check wrong.
    def __eq__(self, b):
        return b is self

    def __ne__(self, b):
        return b is not self

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return b is self

    def __gt__(self, b):
        return b is not self

    def __ge__(self, b):
        return True

    def __hash__(self):
        # Defining __eq__ would otherwise make the class unhashable on
        # Python 3; identity hashing matches the identity-based equality.
        return object.__hash__(self)

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")

    __rsub__ = __sub__

    def __add__(self, b):
        return self

    __radd__ = __add__

    def __mul__(self, b):
        if b == 0:
            return b
        return self

    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self

    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


aleph0 = Aleph0()


def base(line):
    """Return the index of the first pair on diagonal ``line``.

    This is the triangular number ``line * (line + 1) // 2``.
    """
    return line * (line + 1) // 2


def pairToN(pair):
    """Return the enumeration index of ``pair`` (inverse of getNthPair)."""
    x, y = pair
    line, index = x + y, y
    return base(line) + index


def getNthPairInfo(N):
    """Return ``(line, index)`` for the N-th pair.

    ``line`` is the diagonal (``x + y``) the pair lies on and ``index`` is
    its offset from the start of that diagonal.
    """
    # Avoid various singularities
    if N == 0:
        return (0, 0)

    # Gallop to find bounds for line
    line = 1
    upper = 2
    while base(upper) <= N:
        line = upper
        upper = line << 1

    # Binary search for starting line
    lo = line
    hi = line << 1
    while lo + 1 != hi:
        # assert base(lo) <= N < base(hi)
        mid = (lo + hi) >> 1
        if base(mid) <= N:
            lo = mid
        else:
            hi = mid

    line = lo
    return line, N - base(line)


def getNthPair(N):
    """Return the N-th pair ``(x, y)`` of the unbounded diagonal walk."""
    line, index = getNthPairInfo(N)
    return (line - index, index)


def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    Raises ValueError if either bound is not positive or if N >= W * H.
    With ``useDivmod`` the pairs are laid out row by row along the smaller
    bound instead of along diagonals.
    """

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N >= W * H:
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W < H
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # If in lower left, just return as usual
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Otherwise if in upper right, subtract from corner
        if H is not aleph0:
            M = W * H - N - 1
            if M < cornerSize:
                x, y = getNthPair(M)
                return (W - 1 - x, H - 1 - y)

        # Otherwise, compile line and index from number of times we
        # wrap.
        N = N - cornerSize
        index, offset = N % W, N // W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W - 1 - index, 1 + offset + index)


def getNthPairBoundedChecked(
    N, W=aleph0, H=aleph0, useDivmod=False, GNP=getNthPairBounded
):
    """Call ``GNP`` and assert that the result lies inside the bounds."""
    x, y = GNP(N, W, H, useDivmod)
    assert 0 <= x < W and 0 <= y < H
    return x, y


def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for each element.

    With ``useLeftToRight`` the elements are peeled off one at a time;
    otherwise the tuple is split in half recursively.
    """

    if useLeftToRight:
        elts = [None] * W
        for i in range(W):
            elts[i], N = getNthPairBounded(N, H)
        return tuple(elts)
    else:
        if W == 0:
            return ()
        elif W == 1:
            return (N,)
        elif W == 2:
            return getNthPairBounded(N, H, H)
        else:
            LW, RW = W // 2, W - (W // 2)
            L, R = getNthPairBounded(N, H**LW, H**RW)
            return getNthNTuple(
                L, LW, H=H, useLeftToRight=useLeftToRight
            ) + getNthNTuple(R, RW, H=H, useLeftToRight=useLeftToRight)


def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call ``GNT`` and assert the result has W elements, each below H."""
    t = GNT(N, W, H, useLeftToRight)
    assert len(t) == W
    for i in t:
        assert i < H
    return t


def getNthTuple(
    N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False
):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x, 0 <=
    y < maxElement.

    Raises NotImplementedError when maxElement is finite but maxSize is
    not.
    """

    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError("Max element size without max size unhandled")
        # There are maxElement**i tuples of length i.
        bounds = [maxElement**i for i in range(1, maxSize + 1)]
        S, M = getNthPairVariableBounds(N, bounds)
    else:
        S, M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    # S is zero-based, so the tuple length is S + 1.
    return getNthNTuple(M, S + 1, maxElement, useLeftToRight=useLeftToRight)


def getNthTupleChecked(
    N,
    maxSize=aleph0,
    maxElement=aleph0,
    useDivmod=False,
    useLeftToRight=False,
    GNT=getNthTuple,
):
    """Call ``GNT`` and assert the result respects both limits."""
    # FIXME: maxsize is inclusive
    t = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(t) <= maxSize
    for i in t:
        assert i < maxElement
    return t


def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or N is outside
    [0, sum(bounds)).
    """

    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Visit the columns from the smallest bound upwards.  Each step handles
    # the horizontal band between the previous bound and the current one,
    # which is shared by every column that has not been exhausted yet.
    active = list(range(len(bounds)))
    active.sort(key=lambda i: bounds[i])
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        N -= levelSize
        prevLevel = level

    raise RuntimeError("Unexpected loop completion")


def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call ``GNVP`` and assert the result lies inside ``bounds``."""
    x, y = GNVP(N, bounds)
    assert 0 <= x < len(bounds) and 0 <= y < bounds[x]
    return (x, y)


###


def testPairs():
    """Print a 3x6 enumeration in both diagonal and divmod order."""
    W = 3
    H = 6
    a = [["  " for x in range(10)] for y in range(10)]
    b = [["  " for x in range(10)] for y in range(10)]
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (x2, y2))
        a[y][x] = "%2d" % i
        b[y2][x2] = "%2d" % i

    print("-- a --")
    for ln in a[::-1]:
        if "".join(ln).strip():
            print("  ".join(ln))
    print("-- b --")
    for ln in b[::-1]:
        if "".join(ln).strip():
            print("  ".join(ln))


def testPairsVB():
    """Print the first pairs of a variable-bounds enumeration as a grid."""
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    a = [["  " for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        print(i, (x, y))
        a[y][x] = "%2d" % i

    print("-- a --")
    for ln in a[::-1]:
        if "".join(ln).strip():
            print("  ".join(ln))


###

# Toggle to use checked versions of enumeration routines.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked

if __name__ == "__main__":
    testPairs()

    testPairsVB()