xref: /llvm-project/polly/lib/External/isl/imath/tests/gmp-compat-test/gendata.py (revision f98ee40f4b5d7474fc67e82824bf6abbaedb7b1c)
1#!/usr/bin/env python
import itertools
import random

import gmpapi
4
# Decimal-string limits of the C integer types the tests exercise.
# The "long" limits are 64-bit, i.e. these values match an LP64 target
# (64-bit long, 32-bit int, 16-bit short).
MAX_SLONG = str(2 ** 63 - 1)
MIN_SLONG = str(-(2 ** 63))
MAX_ULONG = str(2 ** 64 - 1)
MAX_SINT = str(2 ** 31 - 1)
MIN_SINT = str(-(2 ** 31))
MAX_UINT = str(2 ** 32 - 1)
MAX_SSHORT = str(2 ** 15 - 1)
MIN_SSHORT = str(-(2 ** 15))
MAX_USHORT = str(2 ** 16 - 1)
14
15
def plus1(x):
    """Return the decimal string for the integer *x* plus one."""
    value = int(x)
    value += 1
    return str(value)
18
19
def minus1(x):
    """Return the decimal string for the integer *x* minus one."""
    value = int(x)
    value -= 1
    return str(value)
22
23
def apply(fun, lst):
    """Return a new list holding ``str(fun(item))`` for each item of *lst*."""
    return [str(fun(item)) for item in lst]
26
27
# Zero and one, in signed ("-0", "-1") and unsigned spellings.
mzero_one = ["-0", "-1"]
zero_one = ["0", "1"]

# Extreme values of each C integer type...
mm_slong = [MAX_SLONG, MIN_SLONG]
mm_sint = [MAX_SINT, MIN_SINT]
mm_sshort = [MAX_SSHORT, MIN_SSHORT]
mm_ulong = [MAX_ULONG]
mm_uint = [MAX_UINT]
mm_ushort = [MAX_USHORT]

# ...and the values one step inside those extremes (both ends for the
# signed types, only the top for the unsigned ones).
mm_slong1 = [minus1(mm_slong[0]), plus1(mm_slong[1])]
mm_sint1 = [minus1(mm_sint[0]), plus1(mm_sint[1])]
mm_sshort1 = [minus1(mm_sshort[0]), plus1(mm_sshort[1])]
mm_ulong1 = apply(minus1, mm_ulong)
mm_uint1 = apply(minus1, mm_uint)
mm_ushort1 = apply(minus1, mm_ushort)

mm_all = sum([mm_slong, mm_ulong, mm_sint, mm_uint, mm_sshort, mm_ushort], [])
zero_one_all = mzero_one + zero_one

# mpz operands: signed/unsigned zero and one, every extreme, and every
# extreme +/- 1 (so values just outside each C type's range are included).
mpz_std_list = sum(
    [zero_one_all, mm_all, apply(plus1, mm_all), apply(minus1, mm_all)], []
)
# Signed-long operands: only values that fit in a signed long.
si_std_list = sum(
    [zero_one, mm_slong, mm_sint, mm_sshort, mm_slong1, mm_sint1, mm_sshort1], []
)
# Unsigned-long operands: only values that fit in an unsigned long.
ui_std_list = sum(
    [zero_one, mm_ulong, mm_uint, mm_ushort, mm_ulong1, mm_uint1, mm_ushort1], []
)
52
53
def gen_random_mpz(mindigits=1, maxdigits=100, allowneg=True):
    """Return a random integer as a decimal string.

    The digit count (sign not included) is drawn uniformly from
    [mindigits, maxdigits]. When *allowneg* is true the value is negative
    with probability 1/2. The sign is drawn even when *allowneg* is false,
    so the same number of random draws is consumed either way.
    """
    drawn_sign = random.choice(["", "-"])
    ndigits = random.randint(mindigits, maxdigits)
    prefix = drawn_sign if allowneg else ""
    return prefix + gen_digits(ndigits)
59
60
def gen_random_si():
    """Return a random signed-long value (1-19 digits) as a decimal string.

    Candidates outside [MIN_SLONG, MAX_SLONG] are rejected and redrawn.
    """
    lowest, highest = int(MIN_SLONG), int(MAX_SLONG)
    while True:
        candidate = gen_random_mpz(mindigits=1, maxdigits=19)
        if lowest <= int(candidate) <= highest:
            return candidate
66
67
def gen_random_ui():
    """Return a random unsigned-long value (1-20 digits) as a decimal string.

    Candidates greater than MAX_ULONG are rejected and redrawn.
    """
    limit = int(MAX_ULONG)
    while True:
        candidate = gen_random_mpz(mindigits=1, maxdigits=20, allowneg=False)
        if int(candidate) <= limit:
            return candidate
73
74
def gen_digits(length):
    """Return a random decimal string of *length* digits with no leading zero.

    The first digit is drawn from 1-9 and each following digit from 0-9.
    A *length* below 1 still yields a single digit.
    """
    leading = str(random.randint(1, 9))
    rest = "".join(str(random.randint(0, 9)) for _ in range(length - 1))
    return leading + rest
85
86
def gen_mpzs(mindigits=1, maxdigits=100, count=10):
    """Return *count* random mpz strings with mindigits..maxdigits digits each."""
    values = []
    for _ in range(count):
        values.append(gen_random_mpz(mindigits=mindigits, maxdigits=maxdigits))
    return values
91
92
# Number of random values each gen_* helper produces unless told otherwise.
default_count = 10


def gen_sis(count=default_count):
    """Return *count* random signed-long decimal strings."""
    values = []
    for _ in range(count):
        values.append(gen_random_si())
    return values


def gen_uis(count=default_count):
    """Return *count* random unsigned-long decimal strings."""
    values = []
    for _ in range(count):
        values.append(gen_random_ui())
    return values
102
103
def gen_small_mpzs(count=default_count):
    """Return *count* random mpz strings with 1-4 digits (sign not counted)."""
    low, high = 1, 4
    return gen_mpzs(count=count, mindigits=low, maxdigits=high)
106
107
def is_small_mpz(s):
    """Return True if the decimal string *s* has 1-4 digits.

    A leading minus sign is not counted as a digit. This matches
    gen_small_mpzs(), whose digit count does not include the sign. Counting
    the sign would wrongly reject a value such as "-1234".
    """
    digits = s[1:] if s.startswith("-") else s
    return 1 <= len(digits) <= 4
110
111
def gen_medium_mpzs(count=default_count):
    """Return *count* random mpz strings with 5-20 digits (sign not counted)."""
    low, high = 5, 20
    return gen_mpzs(count=count, mindigits=low, maxdigits=high)
114
115
def is_medium_mpz(s):
    """Return True if the decimal string *s* has 5-20 digits.

    A leading minus sign is not counted as a digit. This matches
    gen_medium_mpzs(), whose digit count does not include the sign. Counting
    it would misclassify "-1234" as medium and a 20-digit negative as not.
    """
    digits = s[1:] if s.startswith("-") else s
    return 5 <= len(digits) <= 20
118
119
def gen_large_mpzs(count=default_count):
    """Return *count* random mpz strings with 21-100 digits (sign not counted)."""
    low, high = 21, 100
    return gen_mpzs(count=count, mindigits=low, maxdigits=high)
122
123
def is_large_mpz(s):
    """Return True if the decimal string *s* has 21 or more digits.

    A leading minus sign is not counted as a digit. This matches
    gen_large_mpzs(), whose digit count does not include the sign. Counting
    it would classify a 20-digit negative number as large.
    """
    digits = s[1:] if s.startswith("-") else s
    return len(digits) >= 21
126
127
def gen_mpz_spread(count=default_count):
    """Return *count* small, then *count* medium, then *count* large mpz strings."""
    spread = gen_small_mpzs(count)
    spread = spread + gen_medium_mpzs(count)
    spread = spread + gen_large_mpzs(count)
    return spread
130
131
def gen_mpz_args(count=default_count):
    """Return the fixed mpz boundary values followed by a random size spread."""
    randoms = gen_mpz_spread(count)
    return mpz_std_list + randoms
134
135
def gen_mpq_args(count=4):
    """Return "num/den" strings pairing every numerator with every denominator.

    Numerators are "0", "1" and a random size spread; denominators are "1"
    and a random size spread. Zero denominators are skipped. The numerator
    varies slowest in the result.
    """
    numerators = zero_one + gen_mpz_spread(count)
    denominators = ["1"] + gen_mpz_spread(count)
    usable_dens = [d for d in denominators if int(d) != 0]
    return ["{}/{}".format(n, d) for n in numerators for d in usable_dens]
140
141
def gen_si_args():
    """Return the fixed signed-long boundary values plus random signed longs."""
    randoms = gen_sis()
    return si_std_list + randoms
144
145
def gen_ui_args():
    """Return the fixed unsigned-long boundary values plus random unsigned longs."""
    randoms = gen_uis()
    return ui_std_list + randoms
148
149
def gen_list_for_type(t, is_write_only):
    """Return candidate argument strings for a parameter of gmpapi type *t*.

    A write-only mpz_t or mpq_t parameter gets the single placeholder "0".
    Otherwise the generator registered for *t* below is called.

    Raises:
        RuntimeError: if *t* is not one of the supported types.
    """
    if is_write_only and (t == gmpapi.mpz_t or t == gmpapi.mpq_t):
        return ["0"]
    # A sequence of pairs rather than a dict, so gmpapi types need not be hashable.
    generators = (
        (gmpapi.mpz_t, gen_mpz_args),
        (gmpapi.ilong, gen_si_args),
        (gmpapi.ulong, gen_ui_args),
        (gmpapi.mpq_t, gen_mpq_args),
    )
    for known_type, generate in generators:
        if t == known_type:
            return generate()
    raise RuntimeError("Unknown type: {}".format(t))
163
164
def gen_args(api):
    """Build the list of argument combinations used to test *api*.

    An API listed in ``custom``, or flagged ``custom_test``, gets its data
    from the matching custom generator. Otherwise each parameter type is
    expanded into candidate values with ``gen_list_for_type``, and the
    cartesian product of those lists is returned, with the first parameter
    varying slowest. A single-parameter API yields one-element lists; a
    multi-parameter API yields tuples.

    Raises:
        RuntimeError: if ``custom_test`` is set but no custom generator is
            registered, or if the API has no parameters or more than four.
    """
    if api.custom_test or api.name in custom:
        if api.name not in custom:
            raise RuntimeError("No custom test data for: {}".format(api.name))
        return custom[api.name](api)
    types = api.params
    if not types:
        # Previously reported as "Too many args: 0", which was misleading.
        raise RuntimeError("No args: {}".format(api.name))
    if len(types) > 4:
        # Cap the arity: the product of candidate lists grows exponentially.
        raise RuntimeError("Too many args: {}".format(len(types)))
    # Build candidate lists in parameter order, so random values are drawn
    # in the same sequence as before.
    candidates = [
        gen_list_for_type(t, api.is_write_only(i)) for i, t in enumerate(types)
    ]
    if len(candidates) == 1:
        return [[a] for a in candidates[0]]
    return list(itertools.product(*candidates))
188
189
190###################################################################
191#
192# Fixup and massage random data for better test coverage
193#
194###################################################################
def mul_mpzs(a, b):
    """Return the product of two integer strings as a decimal string."""
    product = int(a) * int(b)
    return str(product)
197
198
def mpz_divexact_data(args):
    """Rewrite (q, n, d, ...) as [(q, n*d, d)], so that n*d is exactly divisible by d."""
    quotient, numerator, divisor = args[0], args[1], args[2]
    return [(quotient, mul_mpzs(numerator, divisor), divisor)]
203
204
def mpz_divisible_p_data(args):
    """Return a guaranteed-divisible (n*d, d) case followed by the original (n, d)."""
    numerator, divisor = args[0], args[1]
    divisible_case = get_div_data(numerator, divisor, rate=1.0)
    return [divisible_case, (numerator, divisor)]
208
209
def mpz_div3_data(args):
    """Return a guaranteed-divisible (q, n*d, d) case followed by the original (q, n, d)."""
    quotient, numerator, divisor = args[0], args[1], args[2]
    divisible_n, divisible_d = get_div_data(numerator, divisor, rate=1.0)
    return [(quotient, divisible_n, divisible_d), (quotient, numerator, divisor)]
214
215
def mpz_pow_data(args, alwaysallowbase1=True):
    """Keep (rop, base, exp) arguments within a manageable size.

    Special cases are returned unchanged as ``[args]``: base 0, exponent 0
    or 1, and base 1 when *alwaysallowbase1* is true. Otherwise, a base
    outside [-1000, 1000] is replaced by a random value of at most 3 digits,
    and an exponent above 1000 is replaced by a random non-negative value of
    at most 3 digits. The result is a one-element list holding the
    (rop, base, exp) tuple, with base and exp rendered back to strings.
    """
    base = int(args[1])
    exp = int(args[2])
    special = base == 0 or exp in (0, 1) or (alwaysallowbase1 and base == 1)
    if special:
        return [args]
    if not -1000 <= base <= 1000:
        base = gen_random_mpz(maxdigits=3)
    if exp > 1000:
        exp = gen_random_mpz(maxdigits=3, allowneg=False)
    return [(args[0], str(base), str(exp))]
232
233
def mpz_mul_2exp_data(args):
    """Clamp mpz_mul_2exp arguments the same way as mpz_pow_data.

    The difference is that base 1 is not exempt. For op * 2**exp, a base of
    1 does not keep the result small, so the exponent still gets clamped.
    """
    return mpz_pow_data(args, False)
236
237
def mpz_gcd_data(args):
    """Expand one (r, a, b) case into cases with larger common factors.

    Returns, in order:
      - the original triple;
      - (r, a*b, b);
      - (r, a*k, b*k) for a random small, then medium, then large multiplier k.
    """
    r, a, b = args[0], args[1], args[2]
    factors = [gen_small_mpzs(1)[0], gen_medium_mpzs(1)[0], gen_large_mpzs(1)[0]]
    cases = [(r, a, b), (r, mul_mpzs(a, b), b)]
    cases.extend((r, mul_mpzs(a, k), mul_mpzs(b, k)) for k in factors)
    return cases
253
254
def mpz_export_data(api):
    """Return argument tuples for mpz_export / mpz_import.

    Each tuple is (rop, countp, order, size, endian, nails, op). Both word
    orders and word sizes 1, 2, 4 and 8 are combined with a broad mix of
    mpz operands, including 100 operands of 100-1000 digits. *api* is
    unused.
    """
    operands = gen_mpz_args(1000) + gen_mpzs(count=100, mindigits=100, maxdigits=1000)
    return [
        (rop, countp, order, size, endian, nails, op)
        for rop in ["0"]
        for countp in ["0"]
        for order in ["-1", "1"]
        for size in ["1", "2", "4", "8"]
        for endian in ["0"]
        for nails in ["0"]
        for op in operands
    ]
274
275
def mpz_sizeinbase_data(api):
    """Return (op, base) pairs for every base 2..36 and a wide range of operands.

    *api* is unused.
    """
    base_strs = [str(b) for b in range(2, 37)]
    operands = gen_mpz_args(1000) + gen_mpzs(count=1000, mindigits=100, maxdigits=2000)
    return [(op, base) for op in operands for base in base_strs]
280
281
def get_str_data(ty):
    """Return ("NULL", base, op) triples for mpz_get_str / mpq_get_str.

    Bases cover 2..36 followed by -2..-36; the base varies slowest.

    Raises:
        RuntimeError: if *ty* is neither gmpapi.mpz_t nor gmpapi.mpq_t.
    """
    positive = [str(b) for b in range(2, 37)]
    negative = [str(-b) for b in range(2, 37)]
    bases = positive + negative
    if ty == gmpapi.mpz_t:
        operands = gen_mpz_args(1000)
    elif ty == gmpapi.mpq_t:
        operands = gen_mpq_args(20)
    else:
        raise RuntimeError("Unsupported get_str type: " + str(ty))
    return [("NULL", base, op) for base in bases for op in operands]
292
293
def mpz_get_str_data(api):
    """Return mpz_get_str test data; *api* is unused."""
    return get_str_data(ty=gmpapi.mpz_t)
296
297
def mpq_get_str_data(api):
    """Return mpq_get_str test data; *api* is unused."""
    return get_str_data(ty=gmpapi.mpq_t)
300
301
def mpq_set_str_data(api):
    """Return ("0", string, "10") cases for mpq_set_str, nonzero inputs only.

    Inputs are the fractions from gen_mpq_args(20) followed by plain integers
    from gen_mpz_args(). Any input whose numerator is zero is dropped,
    because the parsed results are not canonicalized before comparison, so
    zero would not match exactly. Excluding zero, rather than canonicalizing
    after parsing, keeps set_str and canonicalization tested independently.
    *api* is unused.
    """
    candidates = gen_mpq_args(20) + gen_mpz_args()
    # partition("/")[0] is the numerator of "n/d" and the whole string otherwise.
    return [("0", q, "10") for q in candidates if int(q.partition("/")[0]) != 0]
320
321
def get_div_data(n, d, rate=0.2):
    """Return (n, d), with n replaced by n*d with probability *rate*.

    This makes a fraction of the generated inputs exactly divisible. Exactly
    one random.random() draw is consumed per call.
    """
    make_divisible = random.random() < rate
    return (mul_mpzs(n, d) if make_divisible else n, d)
327
328
def allow(name, args):
    """Return whether *args* passes the ``blacklists`` filters for API *name*.

    An API without a ``blacklists`` entry accepts every argument tuple.
    Otherwise the tuple is rejected as soon as any listed position holds
    one of its excluded values.
    """
    filters = blacklists.get(name, ())
    return not any(args[pos] in excluded for pos, excluded in filters)
337
338
def fixup_args(name, args):
    """Return the list of argument tuples to test for one generated *args*.

    If API *name* has a hook in ``fixups``, the hook's result is returned;
    otherwise *args* is passed through as a one-element list.
    """
    hook = fixups.get(name)
    return [args] if hook is None else hook(args)
343
344
# Argument values to exclude for particular API calls, keyed by API name.
# Each entry is a list of (argument position, [values to exclude]) pairs.
# Every current entry rejects a zero ("0" or "-0") at one position.
blacklists = {
    api_name: [(pos, ["0", "-0"])]
    for api_name, pos in (
        ("mpz_cdiv_q", 2),
        ("mpz_fdiv_q", 2),
        ("mpz_fdiv_r", 2),
        ("mpz_tdiv_q", 2),
        ("mpz_fdiv_q_ui", 2),
        ("mpz_divexact", 2),
        ("mpz_divisible_p", 1),
        ("mpz_divexact_ui", 2),
        ("mpq_set_ui", 2),
    )
}
358
# Per-API hooks that turn one generated argument tuple into a list of
# tuples (e.g. making inputs exactly divisible or keeping powers small).
fixups = dict(
    mpz_divexact=mpz_divexact_data,
    mpz_divisible_p=mpz_divisible_p_data,
    mpz_cdiv_q=mpz_div3_data,
    mpz_fdiv_q=mpz_div3_data,
    mpz_fdiv_r=mpz_div3_data,
    mpz_tdiv_q=mpz_div3_data,
    mpz_fdiv_q_ui=mpz_div3_data,
    mpz_divexact_ui=mpz_divexact_data,
    mpz_pow_ui=mpz_pow_data,
    mpz_gcd=mpz_gcd_data,
    mpz_lcm=mpz_gcd_data,
    mpz_mul_2exp=mpz_mul_2exp_data,
)
373
# APIs whose whole argument list comes from a dedicated generator, which
# gen_args() calls with the api object.
custom = dict(
    mpz_export=mpz_export_data,
    mpz_import=mpz_export_data,
    mpz_sizeinbase=mpz_sizeinbase_data,
    mpz_get_str=mpz_get_str_data,
    mpq_set_str=mpq_set_str_data,
    mpq_get_str=mpq_get_str_data,
)
382
if __name__ == "__main__":
    # Emit one "api_name|arg1,arg2,..." line per test case on stdout.
    # To generate data for a single API while debugging, replace the list
    # below with e.g. [gmpapi.get_api("mpq_set_str")].
    selected_apis = gmpapi.apis
    for api in selected_apis:
        for generated in gen_args(api):
            for case in fixup_args(api.name, generated):
                if not allow(api.name, case):
                    continue
                print("{}|{}".format(api.name, ",".join(case)))
393