"""Flexible enumeration of C types."""

from __future__ import division, print_function

from Enumeration import *

# TODO:

#  - struct improvements (flexible arrays, packed &
#    unpacked, alignment)
#  - objective-c qualified id
#  - anonymous / transparent unions
#  - VLAs
#  - block types
#  - K&R functions
#  - pass arguments of different types (test extension, transparent union)
#  - varargs

###
# Actual type types


class Type(object):
    """Base class for every enumerated C type.

    Non-builtin subclasses implement ``getTypedefDef(name, printer)``; the
    default ``getTypeName`` uses it to declare a fresh typedef for the type.
    """

    def isBitField(self):
        return False

    def isPaddingBitField(self):
        return False

    def getTypeName(self, printer):
        """Declare a typedef ``T<n>`` for this type and return its name.

        ``n`` is ``len(printer.types)``; the typedef text produced by
        ``getTypedefDef`` is passed to ``printer.addDeclaration``.
        """
        name = "T%d" % len(printer.types)
        typedef = self.getTypedefDef(name, printer)
        printer.addDeclaration(typedef)
        return name


class BuiltinType(Type):
    """A named builtin C type, optionally used as a bit-field.

    :param name: C spelling of the type, returned verbatim as its type name.
    :param size: value reported by ``sizeof()``.
    :param bitFieldSize: bit-field width, or ``None`` for an ordinary field.
        A width of 0 marks a padding bit-field, which ``RecordType`` emits
        without a field name.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        # Compare by value: ``is 0`` depended on CPython small-int caching
        # and raises a SyntaxWarning on Python 3.8+.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        # Builtins are referenced directly; no typedef is declared.
        return self.name

    def sizeof(self):
        return self.size

    def __str__(self):
        return self.name


class EnumType(Type):
    """A C enum whose enumerators carry the given optional initializers.

    Each entry of ``enumerators`` is either falsy (no explicit initializer)
    or the initializer text, e.g. ``"-1"``.
    """

    # Class-wide counter; each instance snapshots it so enumerator names stay
    # unique even when two enums share the same ``index``.
    unique_id = 0

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators
        self.unique_id = self.__class__.unique_id
        self.__class__.unique_id += 1

    def getEnumerators(self):
        """Return the comma-separated enumerator list for the enum body."""
        parts = []
        for i, init in enumerate(self.enumerators):
            enumerator = "enum%dval%d_%d" % (self.index, i, self.unique_id)
            if init:
                enumerator = enumerator + " = %s" % (init,)
            parts.append(enumerator)
        return ", ".join(parts)

    def __str__(self):
        return "enum { %s }" % (self.getEnumerators(),)

    def getTypedefDef(self, name, printer):
        return "typedef enum %s { %s } %s;" % (name, self.getEnumerators(), name)


class RecordType(Type):
    """A C struct (or union, when ``isUnion`` is true) of the given fields."""

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        self.fields = fields
        self.name = None

    def __str__(self):
        def getField(t):
            if t.isBitField():
                return "%s : %d;" % (t, t.getBitFieldSize())
            else:
                return "%s;" % t

        return "%s { %s }" % (
            ("struct", "union")[self.isUnion],
            " ".join(map(getField, self.fields)),
        )

    def getTypedefDef(self, name, printer):
        def getField(it):
            i, t = it
            if t.isBitField():
                if t.isPaddingBitField():
                    # A zero-width bit-field must not have a declarator in C.
                    return "%s : 0;" % (printer.getTypeName(t),)
                else:
                    return "%s field%d : %d;" % (
                        printer.getTypeName(t),
                        i,
                        t.getBitFieldSize(),
                    )
            else:
                return "%s field%d;" % (printer.getTypeName(t), i)

        fields = [getField(f) for f in enumerate(self.fields)]
        # Name the struct for more readable LLVM IR.
        return "typedef %s %s { %s } %s;" % (
            ("struct", "union")[self.isUnion],
            name,
            " ".join(fields),
            name,
        )


class ArrayType(Type):
    """A C array, or a GCC vector when ``isVector`` is true.

    For vectors ``size`` is the total size in bytes (must be a positive
    multiple of the element size); for arrays it is the element count, or
    ``None`` for an incomplete array.
    """

    def __init__(self, index, isVector, elementType, size):
        if isVector:
            # Note that for vectors, this is the size in bytes.
            assert size > 0
        else:
            assert size is None or size >= 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if isVector:
            eltSize = self.elementType.sizeof()
            assert not (self.size % eltSize)
            self.numElements = self.size // eltSize
        else:
            self.numElements = self.size

    def __str__(self):
        if self.isVector:
            return "vector (%s)[%d]" % (self.elementType, self.size)
        elif self.size is not None:
            return "(%s)[%d]" % (self.elementType, self.size)
        else:
            return "(%s)[]" % (self.elementType,)

    def getTypedefDef(self, name, printer):
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return "typedef %s %s __attribute__ ((vector_size (%d)));" % (
                elementName,
                name,
                self.size,
            )
        else:
            if self.size is None:
                sizeStr = ""
            else:
                sizeStr = str(self.size)
            return "typedef %s %s[%s];" % (elementName, name, sizeStr)


class ComplexType(Type):
    """A ``_Complex`` type over the given element type."""

    def __init__(self, index, elementType):
        self.index = index
        self.elementType = elementType

    def __str__(self):
        return "_Complex (%s)" % (self.elementType,)

    def getTypedefDef(self, name, printer):
        return "typedef _Complex %s %s;" % (printer.getTypeName(self.elementType), name)


class FunctionType(Type):
    """A pointer-to-function type; a ``returnType`` of ``None`` means void."""

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        self.argTypes = argTypes

    def __str__(self):
        if self.returnType is None:
            rt = "void"
        else:
            rt = str(self.returnType)
        if not self.argTypes:
            at = "void"
        else:
            at = ", ".join(map(str, self.argTypes))
        return "%s (*)(%s)" % (rt, at)

    def getTypedefDef(self, name, printer):
        # Spell component types through the printer, like the other
        # getTypedefDef implementations: str() produces descriptive text such
        # as "struct { int; }" or "vector (int)[16]" that is not valid C.
        if self.returnType is None:
            rt = "void"
        else:
            rt = printer.getTypeName(self.returnType)
        if not self.argTypes:
            at = "void"
        else:
            at = ", ".join(printer.getTypeName(t) for t in self.argTypes)
        return "typedef %s (*%s)(%s);" % (rt, name, at)


###
# Type enumerators


class TypeGenerator(object):
    """Base class for memoized, index-addressable enumerations of types.

    Subclasses assign ``self.cardinality`` in ``setCardinality`` and build
    the N-th type in ``generateType``; ``get`` caches generated types.
    """

    def __init__(self):
        self.cache = {}

    def setCardinality(self):
        raise NotImplementedError("setCardinality must be overridden")

    def get(self, N):
        """Return the N-th type, generating and caching it on first use.

        Raises AssertionError when N is outside ``[0, cardinality)``;
        ``test()`` relies on this to detect a wrong cardinality.
        """
        T = self.cache.get(N)
        if T is None:
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        raise NotImplementedError("generateType must be overridden")


class FixedTypeGenerator(TypeGenerator):
    """Enumerate exactly the types in the given list, in order."""

    def __init__(self, types):
        TypeGenerator.__init__(self)
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.types)

    def generateType(self, N):
        return self.types[N]


# Factorial
def fact(n):
    """Return n! for n >= 0; any n <= 0 yields 1."""
    result = 1
    while n > 0:
        result = result * n
        n = n - 1
    return result


# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return C(n, k); evaluates to 0 when k > n because fact() of a
    negative number is 1 and the floor division then truncates."""
    return fact(n) // (fact(k) * fact(n - k))


# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield every k-element combination of ``values`` as a list, in
    lexicographic order of positions."""
    # From ActiveState Recipe 190465: Generator for permutations,
    # combinations, selections of a sequence
    if k == 0:
        yield []
    else:
        for i in range(len(values) - k + 1):
            for cc in combinations(values[i + 1 :], k - 1):
                yield [values[i]] + cc


class EnumTypeGenerator(TypeGenerator):
    """Enumerate enums whose initializers are combinations of ``values``.

    Enums have between ``minEnumerators`` and ``maxEnumerators`` enumerators
    (inclusive); all enums with fewer enumerators come first.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = 0
        for num in range(self.minEnumerators, self.maxEnumerators + 1):
            self.cardinality += num_combinations(len(self.values), num)

    def generateType(self, n):
        # Figure out the number of enumerators in this type
        numEnumerators = self.minEnumerators
        valuesCovered = 0
        while numEnumerators < self.maxEnumerators:
            comb = num_combinations(len(self.values), numEnumerators)
            if valuesCovered + comb > n:
                break
            numEnumerators = numEnumerators + 1
            valuesCovered += comb

        # Find the requested combination of enumerators and build a
        # type from it.
        i = 0
        for enumerators in combinations(self.values, numEnumerators):
            if i == n - valuesCovered:
                return EnumType(n, enumerators)

            i = i + 1

        assert False


class ComplexTypeGenerator(TypeGenerator):
    """Enumerate ``_Complex T`` for every T produced by ``typeGen``."""

    def __init__(self, typeGen):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        return ComplexType(N, self.typeGen.get(N))


class VectorTypeGenerator(TypeGenerator):
    """Enumerate vectors over (byte size in ``sizes``) x (element type)."""

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.sizes = tuple(map(int, sizes))
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        return ArrayType(N, True, self.typeGen.get(T), self.sizes[S])


class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerate arrays over (element count in ``sizes``) x (element type)."""

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Was ``tuple(size)``: an undefined name that raised NameError.
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # Was ``false``: an undefined name that raised NameError.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])


class ArrayTypeGenerator(TypeGenerator):
    """Enumerate arrays of element counts 1..maxSize over ``typeGen``.

    ``useIncomplete`` adds the incomplete array ``T[]`` (listed first);
    ``useZero`` shifts the counts to 0..maxSize-1... plus maxSize, i.e. it
    adds the zero-length array so counts run 0..maxSize.
    """

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        # Number of distinct size choices per element type.
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete:
            # Slot 0 is reserved for the incomplete array.
            if S == 0:
                size = None
                S = None
            else:
                S = S - 1
        if S is not None:
            if self.useZero:
                size = S
            else:
                size = S + 1
        return ArrayType(N, False, self.typeGen.get(T), size)


class RecordTypeGenerator(TypeGenerator):
    """Enumerate structs (and unions, if ``useUnion``) of 0..maxSize fields.

    Field types come from ``typeGen``. With unions enabled, even indices are
    structs and odd indices are unions.
    """

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        # Keep the aleph0 singleton intact: int() would discard it and make
        # the unbounded branch in setCardinality unreachable.
        if maxSize is aleph0:
            self.maxSize = maxSize
        else:
            self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        M = 1 + self.useUnion
        if self.maxSize is aleph0:
            S = aleph0 * self.typeGen.cardinality
        else:
            S = 0
            for i in range(self.maxSize + 1):
                S += M * (self.typeGen.cardinality**i)
        self.cardinality = S

    def generateType(self, N):
        isUnion, I = False, N
        if self.useUnion:
            isUnion, I = (I & 1), I >> 1
        fields = [
            self.typeGen.get(f)
            for f in getNthTuple(I, self.maxSize, self.typeGen.cardinality)
        ]
        return RecordType(N, isUnion, fields)


class FunctionTypeGenerator(TypeGenerator):
    """Enumerate function-pointer types with 0..maxSize arguments.

    With ``useReturn`` the first element of each generated tuple is the
    return type; otherwise every function returns void.
    """

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        if self.maxSize is aleph0:
            # ``cardinality`` is an attribute/property, not a method.
            S = aleph0 * self.typeGen.cardinality
        elif self.useReturn:
            # Tuples of length 1..maxSize+1 (return type + arguments).
            S = 0
            for i in range(1, self.maxSize + 1 + 1):
                S += self.typeGen.cardinality**i
        else:
            S = 0
            for i in range(self.maxSize + 1):
                S += self.typeGen.cardinality**i
        self.cardinality = S

    def generateType(self, N):
        if self.useReturn:
            # Skip the empty tuple
            argIndices = getNthTuple(N + 1, self.maxSize + 1, self.typeGen.cardinality)
            retIndex, argIndices = argIndices[0], argIndices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)


class AnyTypeGenerator(TypeGenerator):
    """Concatenate several generators into one enumeration.

    The cardinality is the sum of the sub-generators' cardinalities. While
    it is being recomputed (``_cardinality is None``) it reads as aleph0.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        self._cardinality = None

    def getCardinality(self):
        if self._cardinality is None:
            return aleph0
        else:
            return self._cardinality

    def setCardinality(self):
        self.bounds = [g.cardinality for g in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Append ``g`` and recompute cardinalities until they stabilize.

        Every sub-generator is re-evaluated on each pass, presumably because
        generators may reference this one (e.g. as a field-type source).
        Raises RuntimeError if no fixed point is reached in 100 passes.
        """
        self.generators.append(g)
        for i in range(100):
            prev = self._cardinality
            self._cardinality = None
            for g in self.generators:
                g.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or prev == self._cardinality:
                break
        else:
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        index, M = getNthPairVariableBounds(N, self.bounds)
        return self.generators[index].get(M)


def test():
    """Print the first types of a sample enumeration and check that the
    reported cardinality is exactly where ``get`` starts failing."""
    fbtg = FixedTypeGenerator(
        [BuiltinType("char", 4), BuiltinType("char", 4, 0), BuiltinType("int", 4, 5)]
    )

    fields1 = AnyTypeGenerator()
    fields1.addGenerator(fbtg)

    fields0 = AnyTypeGenerator()
    fields0.addGenerator(fbtg)
    #    fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )

    btg = FixedTypeGenerator([BuiltinType("char", 4), BuiltinType("int", 4)])
    etg = EnumTypeGenerator([None, "-1", "1", "1u"], 0, 3)

    atg = AnyTypeGenerator()
    atg.addGenerator(btg)
    atg.addGenerator(RecordTypeGenerator(fields0, False, 4))
    atg.addGenerator(etg)
    print("Cardinality:", atg.cardinality)
    for i in range(100):
        if i == atg.cardinality:
            try:
                atg.get(i)
                raise RuntimeError("Cardinality was wrong")
            except AssertionError:
                break
        print("%4d: %s" % (i, atg.get(i)))


if __name__ == "__main__":
    test()