xref: /llvm-project/clang/utils/ABITest/TypeGen.py (revision dd3c26a045c081620375a878159f536758baba6e)
1"""Flexible enumeration of C types."""
2from __future__ import division, print_function
3
4from Enumeration import *
5
6# TODO:
7
8#  - struct improvements (flexible arrays, packed &
9#    unpacked, alignment)
10#  - objective-c qualified id
11#  - anonymous / transparent unions
12#  - VLAs
13#  - block types
14#  - K&R functions
15#  - pass arguments of different types (test extension, transparent union)
16#  - varargs
17
18###
19# Actual type types
20
21
class Type(object):
    """Common base for every generated C type.

    Subclasses that keep the default ``getTypeName`` must provide
    ``getTypedefDef(name, printer)`` returning a typedef declaration string.
    """

    def isBitField(self):
        """Report whether this type is a bit-field (never, for the base class)."""
        return False

    def isPaddingBitField(self):
        """Report whether this type is a zero-width padding bit-field (never here)."""
        return False

    def getTypeName(self, printer):
        """Register a fresh typedef for this type with *printer*; return its name.

        The name is ``T<k>`` where ``k`` is ``len(printer.types)``.
        """
        fresh = "T%d" % len(printer.types)
        declaration = self.getTypedefDef(fresh, printer)
        printer.addDeclaration(declaration)
        return fresh
34
35
class BuiltinType(Type):
    """A predefined C type such as ``char`` or ``int``, optionally a bit-field.

    ``size`` is what ``sizeof()`` reports. ``bitFieldSize`` is None for an
    ordinary (non-bit-field) type, 0 for a zero-width padding bit-field, and
    the bit width otherwise.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        """Return True when a bit width (including 0) was supplied."""
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        """Return True for a zero-width (padding) bit-field."""
        # Compare by value: an identity test against the literal 0 depends on
        # CPython's small-int cache and emits a SyntaxWarning on Python 3.8+.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        """Return the bit width; only valid when ``isBitField()`` is true."""
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        # Builtins are spelled directly; no typedef is emitted for them.
        return self.name

    def sizeof(self):
        return self.size

    def __str__(self):
        return self.name
60
61
class EnumType(Type):
    """A C enumeration whose enumerators may carry initializer expressions.

    Enumerator names combine the generator index, the enumerator position and
    a per-instance serial number taken from a class-wide counter.
    """

    # Class-wide counter: each new instance takes the current value, then
    # bumps it.
    unique_id = 0

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators
        cls = self.__class__
        self.unique_id = cls.unique_id
        cls.unique_id += 1

    def getEnumerators(self):
        """Return the comma-separated enumerator list used inside the braces.

        A falsy initializer (e.g. None) leaves that enumerator without ``= ...``.
        """
        parts = []
        for position, initializer in enumerate(self.enumerators):
            entry = "enum%dval%d_%d" % (self.index, position, self.unique_id)
            if initializer:
                entry += " = %s" % (initializer)
            parts.append(entry)
        return ", ".join(parts)

    def __str__(self):
        return "enum { %s }" % (self.getEnumerators())

    def getTypedefDef(self, name, printer):
        # The enum tag reuses the typedef name.
        body = self.getEnumerators()
        return "typedef enum %s { %s } %s;" % (name, body, name)
87
88
class RecordType(Type):
    """A struct (or, when ``isUnion`` is truthy, a union) of ordered field types."""

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        self.fields = fields
        self.name = None

    def __str__(self):
        kind = ("struct", "union")[self.isUnion]
        described = []
        for fieldType in self.fields:
            if not fieldType.isBitField():
                described.append("%s;" % fieldType)
            else:
                described.append("%s : %d;" % (fieldType, fieldType.getBitFieldSize()))
        return "%s { %s }" % (kind, " ".join(described))

    def getTypedefDef(self, name, printer):
        """Return a typedef whose members are named ``field<i>``.

        Padding bit-fields are emitted unnamed with width 0. Member types are
        spelled through ``printer.getTypeName``, in field order.
        """
        members = []
        for position, fieldType in enumerate(self.fields):
            if not fieldType.isBitField():
                members.append("%s field%d;" % (printer.getTypeName(fieldType), position))
            elif fieldType.isPaddingBitField():
                members.append("%s : 0;" % (printer.getTypeName(fieldType),))
            else:
                members.append(
                    "%s field%d : %d;"
                    % (printer.getTypeName(fieldType), position, fieldType.getBitFieldSize())
                )
        kind = ("struct", "union")[self.isUnion]
        # The record tag repeats the typedef name so the LLVM IR is easier to read.
        return "typedef %s %s { %s } %s;" % (kind, name, " ".join(members), name)
131
132
class ArrayType(Type):
    """An ordinary array or a GCC vector of ``elementType``.

    For vectors ``size`` is the total size in bytes. Otherwise it is the
    element count, or None for an incomplete array (``T[]``).
    """

    def __init__(self, index, isVector, elementType, size):
        # Vectors need a positive byte size; plain arrays may be incomplete
        # (None) or zero-length.
        if isVector:
            assert size > 0
        else:
            assert size is None or size >= 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if not isVector:
            self.numElements = self.size
            return
        eltSize = self.elementType.sizeof()
        # The vector byte size must be an exact multiple of the element size.
        assert not (self.size % eltSize)
        self.numElements = self.size // eltSize

    def __str__(self):
        if self.isVector:
            return "vector (%s)[%d]" % (self.elementType, self.size)
        if self.size is None:
            return "(%s)[]" % (self.elementType,)
        return "(%s)[%d]" % (self.elementType, self.size)

    def getTypedefDef(self, name, printer):
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return "typedef %s %s __attribute__ ((vector_size (%d)));" % (
                elementName,
                name,
                self.size,
            )
        # An incomplete array gets empty brackets.
        sizeStr = "" if self.size is None else str(self.size)
        return "typedef %s %s[%s];" % (elementName, name, sizeStr)
173
174
class ComplexType(Type):
    """A ``_Complex`` type wrapping ``elementType``."""

    def __init__(self, index, elementType):
        self.index, self.elementType = index, elementType

    def __str__(self):
        return "_Complex (%s)" % (self.elementType)

    def getTypedefDef(self, name, printer):
        elementName = printer.getTypeName(self.elementType)
        return "typedef _Complex %s %s;" % (elementName, name)
185
186
class FunctionType(Type):
    """A pointer-to-function type.

    A ``returnType`` of None means ``void``. An empty ``argTypes`` yields a
    ``(void)`` parameter list.
    """

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        self.argTypes = argTypes

    def _signatureParts(self, spell):
        """Return ``(returnText, paramText)``, spelling each component type with *spell*."""
        if self.returnType is None:
            rt = "void"
        else:
            rt = spell(self.returnType)
        if not self.argTypes:
            at = "void"
        else:
            at = ", ".join([spell(t) for t in self.argTypes])
        return rt, at

    def __str__(self):
        # Human-readable description; component types use their __str__.
        return "%s (*)(%s)" % self._signatureParts(str)

    def getTypedefDef(self, name, printer):
        # Spell component types through the printer, as the record, array and
        # complex typedefs do. str() of a non-builtin type is a description
        # such as "(char)[4]", not a valid C type name.
        rt, at = self._signatureParts(printer.getTypeName)
        return "typedef %s (*%s)(%s);" % (rt, name, at)
214
215
216###
217# Type enumerators
218
219
class TypeGenerator(object):
    """Base class for lazily enumerating a (possibly infinite) family of types.

    Subclasses provide ``cardinality`` (typically assigned in
    ``setCardinality``) and build the N-th type in ``generateType``. ``get``
    memoizes generated types by index.
    """

    def __init__(self):
        # Maps index N -> generated type.
        self.cache = {}

    def setCardinality(self):
        """Compute and store ``self.cardinality``; subclasses must override."""
        raise NotImplementedError(
            "%s must implement setCardinality" % type(self).__name__
        )

    def get(self, N):
        """Return the N-th type, generating and caching it on first request.

        With assertions enabled, an uncached N outside ``[0, cardinality)``
        raises AssertionError.
        """
        T = self.cache.get(N)
        if T is None:
            # Kept as an assert: test() detects an out-of-range index by
            # catching AssertionError.
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        """Build the N-th type; subclasses must override."""
        raise NotImplementedError(
            "%s must implement generateType" % type(self).__name__
        )
236
237
class FixedTypeGenerator(TypeGenerator):
    """Enumerate an explicitly supplied, finite sequence of types."""

    def __init__(self, types):
        super(FixedTypeGenerator, self).__init__()
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        # Exactly one generated type per supplied entry.
        self.cardinality = len(self.types)

    def generateType(self, N):
        chosen = self.types[N]
        return chosen
249
250
251# Factorial
def fact(n):
    """Return n! as the product n * (n - 1) * ... * 1; any n <= 0 gives 1."""
    product = 1
    remaining = n
    while remaining > 0:
        product *= remaining
        remaining -= 1
    return product
258
259
260# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return the binomial coefficient C(n, k), i.e. "n choose k".

    Returns 0 when k < 0 or k > n. Uses the exact multiplicative formula
    instead of three full factorials; every intermediate value is itself a
    binomial coefficient, so the floor division is exact.
    """
    if k < 0 or k > n:
        return 0
    # C(n, k) == C(n, n - k); iterate over the smaller of the two.
    k = min(k, n - k)
    result = 1
    for i in range(1, k + 1):
        # After this step, result == C(n - k + i, i).
        result = result * (n - k + i) // i
    return result
263
264
265# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield every k-element subset of *values* as a list, in index order.

    Based on ActiveState Recipe 190465 (generators for permutations,
    combinations and selections of a sequence). *values* must support
    ``len`` and slicing.
    """
    if k == 0:
        yield []
        return
    lastStart = len(values) - k
    for start in range(lastStart + 1):
        head = values[start]
        for tail in combinations(values[start + 1 :], k - 1):
            yield [head] + tail
275
276
class EnumTypeGenerator(TypeGenerator):
    """Enumerate enums whose enumerator initializers are subsets of ``values``.

    Types are grouped by enumerator count, from ``minEnumerators`` up to
    ``maxEnumerators``. Within a group they follow the order in which
    ``combinations`` yields subsets.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        pool = len(self.values)
        self.cardinality = sum(
            num_combinations(pool, count)
            for count in range(self.minEnumerators, self.maxEnumerators + 1)
        )

    def generateType(self, n):
        # Find the group (enumerator count) containing index n. `skipped` is
        # the number of types in all earlier groups; the last group
        # (maxEnumerators) is never skipped.
        pool = len(self.values)
        numEnumerators = self.minEnumerators
        skipped = 0
        while numEnumerators < self.maxEnumerators:
            groupSize = num_combinations(pool, numEnumerators)
            if n < skipped + groupSize:
                break
            skipped += groupSize
            numEnumerators += 1

        # Build the type from the (n - skipped)-th subset of that size.
        target = n - skipped
        subsets = combinations(self.values, numEnumerators)
        for position, enumerators in enumerate(subsets):
            if position == target:
                return EnumType(n, enumerators)

        assert False
311
312
class ComplexTypeGenerator(TypeGenerator):
    """Enumerate ``_Complex`` types, one per type produced by ``typeGen``."""

    def __init__(self, typeGen):
        super(ComplexTypeGenerator, self).__init__()
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        # One complex type per element type.
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        elementType = self.typeGen.get(N)
        return ComplexType(N, elementType)
324
325
class VectorTypeGenerator(TypeGenerator):
    """Enumerate GCC vector types: each element type paired with each byte size."""

    def __init__(self, typeGen, sizes):
        super(VectorTypeGenerator, self).__init__()
        self.typeGen = typeGen
        # Byte sizes, coerced to int.
        self.sizes = tuple(int(s) for s in sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        sizeIndex, eltIndex = getNthPairBounded(
            N, len(self.sizes), self.typeGen.cardinality
        )
        elementType = self.typeGen.get(eltIndex)
        return ArrayType(N, True, elementType, self.sizes[sizeIndex])
339
340
class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerate complete arrays: each element type from ``typeGen`` paired
    with each length in ``sizes``."""

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Coerce lengths to int, matching VectorTypeGenerator. (This line
        # previously referenced an undefined name `size`.)
        self.sizes = tuple(map(int, sizes))
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # isVector must be the boolean False (lowercase `false` is undefined).
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])
354
355
class ArrayTypeGenerator(TypeGenerator):
    """Enumerate non-vector arrays of every element type from ``typeGen``.

    The size choices are: an optional incomplete array ``T[]``
    (``useIncomplete``), an optional zero-length array (``useZero``), then
    lengths 1..maxSize.
    """

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        # W is the number of distinct size choices (booleans count as 0/1).
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        slot, eltIndex = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete and slot == 0:
            # Slot 0 is reserved for the incomplete array.
            size = None
        else:
            if self.useIncomplete:
                slot -= 1
            # Without useZero, lengths start at 1 rather than 0.
            size = slot if self.useZero else slot + 1
        return ArrayType(N, False, self.typeGen.get(eltIndex), size)
383
384
class RecordTypeGenerator(TypeGenerator):
    """Enumerate structs (and unions, when ``useUnion``) with fields from ``typeGen``.

    The cardinality counts field tuples of length 0..maxSize. ``maxSize`` may
    be the ``aleph0`` sentinel for an unbounded number of fields.
    """

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        # Preserve the aleph0 sentinel: int() returns a new object, so the
        # `is aleph0` check in setCardinality could never succeed after
        # conversion. NOTE(review): assumes aleph0 is the unbounded marker
        # from Enumeration, as FunctionTypeGenerator also treats it.
        if maxSize is aleph0:
            self.maxSize = maxSize
        else:
            self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        # Each field tuple is realised as a struct, and also as a union when
        # useUnion is set.
        M = 1 + self.useUnion
        if self.maxSize is aleph0:
            S = aleph0 * self.typeGen.cardinality
        else:
            S = 0
            for i in range(self.maxSize + 1):
                S += M * (self.typeGen.cardinality**i)
        self.cardinality = S

    def generateType(self, N):
        isUnion, I = False, N
        if self.useUnion:
            # The low bit selects struct/union; the remaining bits index the
            # field tuple.
            isUnion, I = (I & 1), I >> 1
        fields = [
            self.typeGen.get(f)
            for f in getNthTuple(I, self.maxSize, self.typeGen.cardinality)
        ]
        return RecordType(N, isUnion, fields)
412
413
class FunctionTypeGenerator(TypeGenerator):
    """Enumerate function-pointer types whose component types come from ``typeGen``.

    Without ``useReturn`` every function returns void and the cardinality
    counts argument tuples of length 0..maxSize. With ``useReturn`` the first
    tuple element is the return type, so tuples have length 1..maxSize+1.
    """

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        if self.maxSize is aleph0:
            # `cardinality` is an attribute (a property on AnyTypeGenerator),
            # not a method, so it must not be called.
            S = aleph0 * self.typeGen.cardinality
        elif self.useReturn:
            # Return type plus 0..maxSize arguments.
            S = 0
            for i in range(1, self.maxSize + 1 + 1):
                S += self.typeGen.cardinality**i
        else:
            S = 0
            for i in range(self.maxSize + 1):
                S += self.typeGen.cardinality**i
        self.cardinality = S

    def generateType(self, N):
        if self.useReturn:
            # Skip the empty tuple
            argIndices = getNthTuple(N + 1, self.maxSize + 1, self.typeGen.cardinality)
            retIndex, argIndices = argIndices[0], argIndices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)
446
447
class AnyTypeGenerator(TypeGenerator):
    """Enumerate the concatenation of several sub-generators.

    ``cardinality`` is the sum of the sub-generators' cardinalities. It
    reports ``aleph0`` while the value is unknown (``_cardinality is None``),
    including while ``addGenerator`` is recomputing it.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # Reset to "unknown" so cardinality reads as aleph0 until the first
        # generator is added.
        self._cardinality = None

    def getCardinality(self):
        if self._cardinality is None:
            return aleph0
        else:
            return self._cardinality

    def setCardinality(self):
        # bounds[i] is the number of types contributed by generators[i].
        self.bounds = [g.cardinality for g in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Append generator *g* and recompute cardinalities to a fixed point.

        Each round marks this generator's cardinality unknown (so it reads as
        aleph0), recomputes every sub-generator, then recomputes the total.
        The loop stops when the total is aleph0 or unchanged from the previous
        round. It raises RuntimeError if that has not happened after 100
        rounds.
        """
        self.generators.append(g)
        for _ in range(100):
            prev = self._cardinality
            self._cardinality = None
            # Use a distinct name so the loop does not shadow the parameter g.
            for child in self.generators:
                child.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or prev == self._cardinality:
                break
        else:
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        # Map N to (which sub-generator, index within it).
        index, M = getNthPairVariableBounds(N, self.bounds)
        return self.generators[index].get(M)
484
485
def test():
    """Smoke test: print up to 100 types from a small mixed enumeration.

    When the index reaches the reported cardinality, fetching it must fail the
    range assertion in ``TypeGenerator.get``. Otherwise RuntimeError is raised.
    """
    fieldBase = FixedTypeGenerator(
        [BuiltinType("char", 4), BuiltinType("char", 4, 0), BuiltinType("int", 4, 5)]
    )

    nestedFields = AnyTypeGenerator()
    nestedFields.addGenerator(fieldBase)

    recordFields = AnyTypeGenerator()
    recordFields.addGenerator(fieldBase)
    #    recordFields.addGenerator( RecordTypeGenerator(nestedFields, False, 4) )

    builtins = FixedTypeGenerator([BuiltinType("char", 4), BuiltinType("int", 4)])
    enums = EnumTypeGenerator([None, "-1", "1", "1u"], 0, 3)

    everything = AnyTypeGenerator()
    everything.addGenerator(builtins)
    everything.addGenerator(RecordTypeGenerator(recordFields, False, 4))
    everything.addGenerator(enums)
    print("Cardinality:", everything.cardinality)

    for index in range(100):
        if index == everything.cardinality:
            try:
                everything.get(index)
            except AssertionError:
                break
            raise RuntimeError("Cardinality was wrong")
        print("%4d: %s" % (index, everything.get(index)))
514
515
if __name__ == "__main__":
    # Running this file directly executes the smoke test defined above.
    test()
518