"""Flexible enumeration of C types."""
from __future__ import division, print_function

from Enumeration import *

# TODO:

# - struct improvements (flexible arrays, packed &
#   unpacked, alignment)
# - objective-c qualified id
# - anonymous / transparent unions
# - VLAs
# - block types
# - K&R functions
# - pass arguments of different types (test extension, transparent union)
# - varargs

###
# Actual type types


class Type(object):
    """Base class for the C types modelled by this module.

    Subclasses other than BuiltinType are expected to provide
    ``getTypedefDef(name, printer)``, which getTypeName() relies on.
    """

    def isBitField(self):
        """Return True if this type is a bit-field; False by default."""
        return False

    def isPaddingBitField(self):
        """Return True if this type is a zero-width bit-field; False by default."""
        return False

    def getTypeName(self, printer):
        """Declare a typedef for this type on `printer` and return its name.

        The name is "T<k>", where k is the current length of `printer.types`.
        The typedef text comes from the subclass's getTypedefDef() and is
        handed to `printer.addDeclaration`.
        """
        name = "T%d" % len(printer.types)
        typedef = self.getTypedefDef(name, printer)
        printer.addDeclaration(typedef)
        return name


class BuiltinType(Type):
    """A named builtin type, optionally used as a bit-field.

    `bitFieldSize` is None for an ordinary type, otherwise the bit-field
    width; a width of 0 marks a padding bit-field.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        """Return True if a bit-field width was supplied."""
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        """Return True if this is a zero-width bit-field."""
        # Compare by value: identity tests against an int literal depend on
        # interpreter caching and trigger a SyntaxWarning on newer Pythons.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        """Return the bit-field width; only valid when isBitField() is True."""
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        """Return the builtin's own name; no typedef is declared."""
        return self.name

    def sizeof(self):
        """Return the size this type was constructed with."""
        return self.size

    def __str__(self):
        return self.name


class EnumType(Type):
    """An enum type with generated enumerator names.

    `enumerators` is a sequence of initializers; a falsy entry (e.g. None)
    yields an enumerator without an explicit value.
    """

    # Class-wide counter; each instance takes the current value so that
    # enumerator names from different EnumType instances do not collide.
    unique_id = 0

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators
        self.unique_id = self.__class__.unique_id
        self.__class__.unique_id += 1

    def getEnumerators(self):
        """Return the comma-separated enumerator list for the enum body."""
        parts = []
        for i, init in enumerate(self.enumerators):
            text = "enum%dval%d_%d" % (self.index, i, self.unique_id)
            if init:
                text += " = %s" % (init,)
            parts.append(text)
        return ", ".join(parts)

    def __str__(self):
        return "enum { %s }" % (self.getEnumerators())

    def getTypedefDef(self, name, printer):
        """Return a typedef declaring this enum under `name`."""
        return "typedef enum %s { %s } %s;" % (name, self.getEnumerators(), name)


class RecordType(Type):
    """A struct (isUnion falsy) or union (isUnion truthy) over `fields`."""

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        self.fields = fields
        self.name = None

    def __str__(self):
        def getField(t):
            if t.isBitField():
                return "%s : %d;" % (t, t.getBitFieldSize())
            else:
                return "%s;" % t

        return "%s { %s }" % (
            # isUnion is used as an index: 0/False -> struct, 1/True -> union.
            ("struct", "union")[self.isUnion],
            " ".join([getField(f) for f in self.fields]),
        )

    def getTypedefDef(self, name, printer):
        """Return a typedef for this record with members named field<i>.

        Padding (zero-width) bit-fields are emitted without a member name.
        """

        def getField(it):
            i, t = it
            if t.isBitField():
                if t.isPaddingBitField():
                    return "%s : 0;" % (printer.getTypeName(t),)
                else:
                    return "%s field%d : %d;" % (
                        printer.getTypeName(t),
                        i,
                        t.getBitFieldSize(),
                    )
            else:
                return "%s field%d;" % (printer.getTypeName(t), i)

        fields = [getField(f) for f in enumerate(self.fields)]
        # Name the struct for more readable LLVM IR.
        return "typedef %s %s { %s } %s;" % (
            ("struct", "union")[self.isUnion],
            name,
            " ".join(fields),
            name,
        )


class ArrayType(Type):
    """An array or vector of `elementType`.

    For vectors `size` is the size in bytes and must be a positive multiple
    of the element size. For arrays `size` is the element count, or None
    for an array of unspecified size.
    """

    def __init__(self, index, isVector, elementType, size):
        if isVector:
            # Note that for vectors, this is the size in bytes.
            assert size > 0
        else:
            assert size is None or size >= 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if isVector:
            eltSize = self.elementType.sizeof()
            assert not (self.size % eltSize)
            self.numElements = self.size // eltSize
        else:
            self.numElements = self.size

    def __str__(self):
        if self.isVector:
            return "vector (%s)[%d]" % (self.elementType, self.size)
        elif self.size is not None:
            return "(%s)[%d]" % (self.elementType, self.size)
        else:
            return "(%s)[]" % (self.elementType,)

    def getTypedefDef(self, name, printer):
        """Return a typedef for this array or vector type under `name`."""
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return "typedef %s %s __attribute__ ((vector_size (%d)));" % (
                elementName,
                name,
                self.size,
            )
        else:
            if self.size is None:
                sizeStr = ""
            else:
                sizeStr = str(self.size)
            return "typedef %s %s[%s];" % (elementName, name, sizeStr)


class ComplexType(Type):
    """A _Complex type over `elementType`."""

    def __init__(self, index, elementType):
        self.index = index
        self.elementType = elementType

    def __str__(self):
        return "_Complex (%s)" % (self.elementType)

    def getTypedefDef(self, name, printer):
        """Return a typedef for this complex type under `name`."""
        return "typedef _Complex %s %s;" % (printer.getTypeName(self.elementType), name)


class FunctionType(Type):
    """A function pointer type.

    A `returnType` of None and an empty `argTypes` are both rendered as
    "void".
    """

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        self.argTypes = argTypes

    def _signatureParts(self):
        """Return (return type text, argument list text) for this function."""
        if self.returnType is None:
            rt = "void"
        else:
            rt = str(self.returnType)
        if not self.argTypes:
            at = "void"
        else:
            at = ", ".join(map(str, self.argTypes))
        return rt, at

    def __str__(self):
        rt, at = self._signatureParts()
        return "%s (*)(%s)" % (rt, at)

    def getTypedefDef(self, name, printer):
        """Return a typedef for this function pointer type under `name`.

        NOTE(review): unlike the other types, the return and argument types
        are rendered with str() rather than printer.getTypeName(); this
        mirrors the existing behavior -- confirm whether it is intended.
        """
        rt, at = self._signatureParts()
        return "typedef %s (*%s)(%s);" % (rt, name, at)


###
# Type enumerators


class TypeGenerator(object):
    """Base class for indexed type enumerators.

    Subclasses set `self.cardinality` in setCardinality() and build the
    N-th type in generateType(N). Results are memoized by get().
    """

    def __init__(self):
        self.cache = {}

    def setCardinality(self):
        """Compute and store `self.cardinality`; must be overridden."""
        raise NotImplementedError("setCardinality must be overridden")

    def get(self, N):
        """Return the N-th type, generating and caching it on first use.

        Asserts that 0 <= N < cardinality before generating a new type.
        """
        T = self.cache.get(N)
        if T is None:
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        """Build the N-th type; must be overridden."""
        raise NotImplementedError("generateType must be overridden")


class FixedTypeGenerator(TypeGenerator):
    """Enumerates a fixed, explicit list of types."""

    def __init__(self, types):
        TypeGenerator.__init__(self)
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.types)

    def generateType(self, N):
        return self.types[N]


# Factorial
def fact(n):
    """Return n!; any n <= 0 yields 1."""
    result = 1
    while n > 0:
        result = result * n
        n = n - 1
    return result


# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return the number of ways to choose k items from n."""
    return fact(n) // (fact(k) * fact(n - k))


# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield each k-element combination of `values` as a list, in order."""
    # From ActiveState Recipe 190465: Generator for permutations,
    # combinations, selections of a sequence
    if k == 0:
        yield []
    else:
        for i in range(len(values) - k + 1):
            for cc in combinations(values[i + 1 :], k - 1):
                yield [values[i]] + cc


class EnumTypeGenerator(TypeGenerator):
    """Enumerates enum types built from combinations of `values`.

    Types with minEnumerators enumerators come first, then one more, and so
    on up to maxEnumerators.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = 0
        for num in range(self.minEnumerators, self.maxEnumerators + 1):
            self.cardinality += num_combinations(len(self.values), num)

    def generateType(self, n):
        # Figure out the number of enumerators in this type
        numEnumerators = self.minEnumerators
        valuesCovered = 0
        while numEnumerators < self.maxEnumerators:
            comb = num_combinations(len(self.values), numEnumerators)
            if valuesCovered + comb > n:
                break
            numEnumerators = numEnumerators + 1
            valuesCovered += comb

        # Find the requested combination of enumerators and build a
        # type from it.
        i = 0
        for enumerators in combinations(self.values, numEnumerators):
            if i == n - valuesCovered:
                return EnumType(n, enumerators)

            i = i + 1

        assert False


class ComplexTypeGenerator(TypeGenerator):
    """Enumerates _Complex types, one per type of the wrapped generator."""

    def __init__(self, typeGen):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        return ComplexType(N, self.typeGen.get(N))


class VectorTypeGenerator(TypeGenerator):
    """Enumerates vector types over every (size, element type) pair."""

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.sizes = tuple(map(int, sizes))
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        return ArrayType(N, True, self.typeGen.get(T), self.sizes[S])


class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerates array types over every (size, element type) pair,
    with sizes drawn from an explicit list."""

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Fixed: the original read the undefined name `size` here.
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # Fixed: the original passed the undefined name `false`.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])


class ArrayTypeGenerator(TypeGenerator):
    """Enumerates array types with sizes 1..maxSize, optionally also
    including the unsized array (useIncomplete) and size 0 (useZero)."""

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        # Number of distinct size choices; the booleans each add one.
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete:
            # Slot 0 is the unsized array; later slots shift down by one.
            if S == 0:
                size = None
                S = None
            else:
                S = S - 1
        if S is not None:
            if self.useZero:
                size = S
            else:
                size = S + 1
        return ArrayType(N, False, self.typeGen.get(T), size)


class RecordTypeGenerator(TypeGenerator):
    """Enumerates struct (and, with useUnion, union) types whose fields
    are tuples of up to maxSize types from the wrapped generator."""

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        # Two record kinds when unions are enabled, otherwise one.
        M = 1 + self.useUnion
        if self.maxSize is aleph0:
            S = aleph0 * self.typeGen.cardinality
        else:
            S = 0
            for i in range(self.maxSize + 1):
                S += M * (self.typeGen.cardinality**i)
        self.cardinality = S

    def generateType(self, N):
        isUnion, I = False, N
        if self.useUnion:
            # The low bit selects struct vs. union; the rest picks fields.
            isUnion, I = (I & 1), I >> 1
        fields = [
            self.typeGen.get(f)
            for f in getNthTuple(I, self.maxSize, self.typeGen.cardinality)
        ]
        return RecordType(N, isUnion, fields)


class FunctionTypeGenerator(TypeGenerator):
    """Enumerates function types with up to maxSize arguments, with an
    optional non-void return type (useReturn)."""

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        if self.maxSize is aleph0:
            # Fixed: `cardinality` is an attribute/property on every
            # generator in this module, not a method, so it is not called.
            S = aleph0 * self.typeGen.cardinality
        elif self.useReturn:
            # Tuples of length 1..maxSize+1: a return type plus arguments.
            S = 0
            for i in range(1, self.maxSize + 1 + 1):
                S += self.typeGen.cardinality**i
        else:
            S = 0
            for i in range(self.maxSize + 1):
                S += self.typeGen.cardinality**i
        self.cardinality = S

    def generateType(self, N):
        if self.useReturn:
            # Skip the empty tuple
            argIndices = getNthTuple(N + 1, self.maxSize + 1, self.typeGen.cardinality)
            retIndex, argIndices = argIndices[0], argIndices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)


class AnyTypeGenerator(TypeGenerator):
    """Enumerates the combined types of several registered generators.

    `cardinality` is a read-only property: it reports aleph0 while the
    cached total is None, otherwise the sum of the generators' sizes.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # Reset after setCardinality(): with no generators registered the
        # cardinality is reported as aleph0 rather than 0.
        self._cardinality = None

    def getCardinality(self):
        """Return the cached total, or aleph0 if none has been computed."""
        if self._cardinality is None:
            return aleph0
        else:
            return self._cardinality

    def setCardinality(self):
        self.bounds = [g.cardinality for g in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Register generator `g` and recompute all cardinalities.

        Recomputation repeats (at most 100 times) until the total stops
        changing or becomes aleph0, so that generators which refer back to
        this one can settle; RuntimeError is raised if it never does.
        """
        self.generators.append(g)
        for i in range(100):
            prev = self._cardinality
            self._cardinality = None
            for gen in self.generators:
                gen.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or prev == self._cardinality:
                break
        else:
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        index, M = getNthPairVariableBounds(N, self.bounds)
        return self.generators[index].get(M)


def test():
    """Smoke test: print the first types of a small combined generator.

    Also checks that asking for the type at index == cardinality trips the
    bounds assertion in TypeGenerator.get().
    """
    fbtg = FixedTypeGenerator(
        [BuiltinType("char", 4), BuiltinType("char", 4, 0), BuiltinType("int", 4, 5)]
    )

    fields1 = AnyTypeGenerator()
    fields1.addGenerator(fbtg)

    fields0 = AnyTypeGenerator()
    fields0.addGenerator(fbtg)
    # fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )

    btg = FixedTypeGenerator([BuiltinType("char", 4), BuiltinType("int", 4)])
    etg = EnumTypeGenerator([None, "-1", "1", "1u"], 0, 3)

    atg = AnyTypeGenerator()
    atg.addGenerator(btg)
    atg.addGenerator(RecordTypeGenerator(fields0, False, 4))
    atg.addGenerator(etg)
    print("Cardinality:", atg.cardinality)
    for i in range(100):
        if i == atg.cardinality:
            try:
                atg.get(i)
                raise RuntimeError("Cardinality was wrong")
            except AssertionError:
                break
        print("%4d: %s" % (i, atg.get(i)))


if __name__ == "__main__":
    test()