#!/usr/bin/env python
import random
import gmpapi

# Extreme values of the C integer types, as decimal strings:
# 64-bit signed/unsigned long, 32-bit signed/unsigned int,
# 16-bit signed/unsigned short.
MAX_SLONG = "9223372036854775807"
MIN_SLONG = "-9223372036854775808"
MAX_ULONG = "18446744073709551615"
MAX_SINT = "2147483647"
MIN_SINT = "-2147483648"
MAX_UINT = "4294967295"
MAX_SSHORT = "32767"
MIN_SSHORT = "-32768"
MAX_USHORT = "65535"


def plus1(x):
    """Return the decimal string one greater than the number *x*."""
    value = int(x)
    return str(value + 1)


def minus1(x):
    """Return the decimal string one less than the number *x*."""
    value = int(x)
    return str(value - 1)


def apply(fun, lst):
    """Call *fun* on each element of *lst* and return the results as strings."""
    return [str(fun(item)) for item in lst]


# Zero/one, including the "-0" spelling, which must parse like "0".
mzero_one = ["-0", "-1"]
zero_one = ["0", "1"]

# Type extremes ("mm_*") and the values one step inside them ("mm_*1").
mm_slong = [MAX_SLONG, MIN_SLONG]
mm_slong1 = [minus1(MAX_SLONG), plus1(MIN_SLONG)]
mm_ulong = [MAX_ULONG]
mm_ulong1 = [minus1(MAX_ULONG)]
mm_sint = [MAX_SINT, MIN_SINT]
mm_sint1 = [minus1(MAX_SINT), plus1(MIN_SINT)]
mm_uint = [MAX_UINT]
mm_uint1 = [minus1(MAX_UINT)]
mm_sshort = [MAX_SSHORT, MIN_SSHORT]
mm_sshort1 = [minus1(MAX_SSHORT), plus1(MIN_SSHORT)]
mm_ushort = [MAX_USHORT]
mm_ushort1 = [minus1(MAX_USHORT)]

mm_all = mm_slong + mm_ulong + mm_sint + mm_uint + mm_sshort + mm_ushort
zero_one_all = mzero_one + zero_one

# Fixed mpz inputs: zero/one plus every extreme and its +/-1 neighbours.
mpz_std_list = (
    zero_one_all
    + mm_all
    + apply(plus1, mm_all)
    + apply(minus1, mm_all)
)
# Fixed inputs that fit a signed long.
si_std_list = (
    zero_one
    + mm_slong
    + mm_sint
    + mm_sshort
    + mm_slong1
    + mm_sint1
    + mm_sshort1
)
# Fixed inputs that fit an unsigned long.
ui_std_list = (
    zero_one
    + mm_ulong
    + mm_uint
    + mm_ushort
    + mm_ulong1
    + mm_uint1
    + mm_ushort1
)


def gen_random_mpz(mindigits=1, maxdigits=100, allowneg=True):
    """Return a random decimal integer string.

    The magnitude has between *mindigits* and *maxdigits* digits (inclusive).
    It is built by ``gen_digits``. A "-" sign is prefixed at random unless
    *allowneg* is false.
    """
    # The sign is drawn even when allowneg is false, so the random stream is
    # consumed the same way either way.
    drawn_sign = random.choice(["", "-"])
    ndigits = random.randint(mindigits, maxdigits)
    prefix = drawn_sign if allowneg else ""
    return prefix + gen_digits(ndigits)


def gen_random_si():
    """Return a random decimal string within [MIN_SLONG, MAX_SLONG]."""
    low, high = int(MIN_SLONG), int(MAX_SLONG)
    while True:
        candidate = gen_random_mpz(mindigits=1, maxdigits=19)
        if low <= int(candidate) <= high:
            return candidate


def gen_random_ui():
    """Return a random non-negative decimal string no larger than MAX_ULONG."""
    high = int(MAX_ULONG)
    while True:
        candidate = gen_random_mpz(mindigits=1, maxdigits=20, allowneg=False)
        if int(candidate) <= high:
            return candidate
def gen_digits(length):
    """Return a random string of *length* decimal digits with no leading zero.

    A *length* below 1 still yields a single digit.
    """
    # The first digit comes from 1-9 so the result never starts with "0".
    # The remaining length - 1 digits may be any of 0-9.
    digits = [random.randint(1, 9)] + [
        random.randint(0, 9) for _ in range(length - 1)
    ]
    return "".join(map(str, digits))


def gen_mpzs(mindigits=1, maxdigits=100, count=10):
    """Return *count* random integer strings of mindigits..maxdigits digits."""
    return [
        gen_random_mpz(mindigits=mindigits, maxdigits=maxdigits) for _ in range(count)
    ]


default_count = 10


def gen_sis(count=default_count):
    """Return *count* random strings that fit a signed long."""
    return [gen_random_si() for _ in range(count)]


def gen_uis(count=default_count):
    """Return *count* random strings that fit an unsigned long."""
    return [gen_random_ui() for _ in range(count)]


def _num_digits(s):
    """Return the number of digits in *s*, not counting a leading '-'."""
    return len(s) - 1 if s.startswith("-") else len(s)


def gen_small_mpzs(count=default_count):
    """Return *count* random integers of 1 to 4 digits."""
    return gen_mpzs(mindigits=1, maxdigits=4, count=count)


def is_small_mpz(s):
    """Return True if *s* has 1 to 4 digits (the range of gen_small_mpzs)."""
    # The sign is ignored so that negative outputs of gen_small_mpzs
    # (e.g. "-1234") are still classified as small.
    return 1 <= _num_digits(s) <= 4


def gen_medium_mpzs(count=default_count):
    """Return *count* random integers of 5 to 20 digits."""
    return gen_mpzs(mindigits=5, maxdigits=20, count=count)


def is_medium_mpz(s):
    """Return True if *s* has 5 to 20 digits (the range of gen_medium_mpzs)."""
    return 5 <= _num_digits(s) <= 20


def gen_large_mpzs(count=default_count):
    """Return *count* random integers of 21 to 100 digits."""
    return gen_mpzs(mindigits=21, maxdigits=100, count=count)


def is_large_mpz(s):
    """Return True if *s* has at least 21 digits (gen_large_mpzs range)."""
    return _num_digits(s) >= 21


def gen_mpz_spread(count=default_count):
    """Return *count* small, then *count* medium, then *count* large integers."""
    return gen_small_mpzs(count) + gen_medium_mpzs(count) + gen_large_mpzs(count)


def gen_mpz_args(count=default_count):
    """Return the fixed mpz inputs followed by a random size spread."""
    return mpz_std_list + gen_mpz_spread(count)


def gen_mpq_args(count=4):
    """Return "num/den" strings from the cross product of numerators and denominators.

    Numerators are "0", "1" and a random spread. Denominators are "1" and a
    random spread. Zero denominators are skipped.
    """
    nums = zero_one + gen_mpz_spread(count)
    dens = ["1"] + gen_mpz_spread(count)
    return [n + "/" + d for n in nums for d in dens if int(d) != 0]


def gen_si_args():
    """Return the fixed signed-long inputs plus random ones."""
    return si_std_list + gen_sis()


def gen_ui_args():
    """Return the fixed unsigned-long inputs plus random ones."""
    return ui_std_list + gen_uis()


def gen_list_for_type(t, is_write_only):
    """Return candidate argument strings for a parameter of gmpapi type *t*.

    Write-only mpz/mpq parameters get the single placeholder "0".

    Raises:
        RuntimeError: if *t* is not a supported type.
    """
    if (t == gmpapi.mpz_t or t == gmpapi.mpq_t) and is_write_only:
        return ["0"]
    elif t == gmpapi.mpz_t:
        return gen_mpz_args()
    elif t == gmpapi.ilong:
        return gen_si_args()
    elif t == gmpapi.ulong:
        return gen_ui_args()
    elif t == gmpapi.mpq_t:
        return gen_mpq_args()
    else:
        raise RuntimeError("Unknown type: {}".format(t))


def gen_args(api):
    """Return every argument combination to test for *api*.

    If the api is registered in ``custom`` (or flags ``custom_test``), its
    custom generator is used. Otherwise the result is the cross product of the
    per-parameter candidate lists. A single-parameter api yields one-element
    lists; larger arities yield tuples.

    Raises:
        RuntimeError: if a custom test is requested but no generator is
            registered, or the api has 0 or more than 4 parameters.
    """
    import itertools

    if api.custom_test or api.name in custom:
        if api.name not in custom:
            raise RuntimeError("No custom test generator for: {}".format(api.name))
        return custom[api.name](api)
    types = api.params
    # Larger arities are refused rather than producing a huge cross product.
    if not 1 <= len(types) <= 4:
        raise RuntimeError("Too many args: {}".format(len(types)))
    # Lists are built in parameter order, so random draws happen in that order.
    candidates = [
        gen_list_for_type(t, api.is_write_only(i)) for i, t in enumerate(types)
    ]
    if len(candidates) == 1:
        return [[a] for a in candidates[0]]
    # product() varies the last parameter fastest, like nested for-loops.
    return list(itertools.product(*candidates))


###################################################################
#
# Fixup and massage random data for better test coverage
#
###################################################################
def mul_mpzs(a, b):
    """Return the product of the integer strings *a* and *b* as a string."""
    return str(int(a) * int(b))


def mpz_divexact_data(args):
    """Replace the dividend with dividend * divisor so the division is exact."""
    divisible = mul_mpzs(args[1], args[2])
    return [(args[0], divisible, args[2])]


def mpz_divisible_p_data(args):
    """Return a guaranteed-divisible (n*d, d) pair plus the original pair."""
    (n, d) = get_div_data(args[0], args[1], rate=1.0)
    return [(n, d), (args[0], args[1])]


def mpz_div3_data(args):
    """For (q, n, d) return an exact-division variant plus the original args."""
    q = args[0]
    (n, d) = get_div_data(args[1], args[2], rate=1.0)
    return [(q, n, d), (q, args[1], args[2])]


def mpz_pow_data(args, alwaysallowbase1=True):
    """Keep (rop, base, exp) args within a size that is practical to compute.

    Base 0, exponent 0 or 1, and (if *alwaysallowbase1*) base 1 pass
    through unchanged. Otherwise a base outside [-1000, 1000] or an exponent
    above 1000 is replaced by a random value of at most 3 digits.
    """
    base = int(args[1])
    exp = int(args[2])
    # allow special numbers
    if base == 0 or exp == 0 or exp == 1:
        return [args]
    if base == 1 and alwaysallowbase1:
        return [args]

    # disallow too big numbers
    if base > 1000 or base < -1000:
        base = gen_random_mpz(maxdigits=3)
    if exp > 1000:
        exp = gen_random_mpz(maxdigits=3, allowneg=False)

    return [(args[0], str(base), str(exp))]


def mpz_mul_2exp_data(args):
    """Like mpz_pow_data, but base 1 is not exempt from the size limits."""
    return mpz_pow_data(args, alwaysallowbase1=False)


def mpz_gcd_data(args):
    """Return (r, a, b) plus variants where a and b share a nontrivial factor.

    The variants are: a*b with b, and a and b both multiplied by a random
    small, medium, or large common factor.
    """
    r = args[0]
    a = args[1]
    b = args[2]
    s_ = gen_small_mpzs(1)[0]
    m_ = gen_medium_mpzs(1)[0]
    l_ = gen_large_mpzs(1)[0]

    return [
        (r, a, b),
        (r, mul_mpzs(a, b), b),
        (r, mul_mpzs(a, s_), mul_mpzs(b, s_)),
        (r, mul_mpzs(a, m_), mul_mpzs(b, m_)),
        (r, mul_mpzs(a, l_), mul_mpzs(b, l_)),
    ]


def mpz_export_data(api):
    """Return (rop, countp, order, size, endian, nails, op) argument tuples.

    The tuples cover both orders, word sizes 1/2/4/8, and many operands,
    including 100 operands of 100-1000 digits.
    """
    rop = ["0"]
    countp = ["0"]
    order = ["-1", "1"]
    size = ["1", "2", "4", "8"]
    endian = ["0"]
    nails = ["0"]
    ops = gen_mpz_args(1000) + gen_mpzs(count=100, mindigits=100, maxdigits=1000)

    return [
        (r, c, o, s, e, n, op)
        for r in rop
        for c in countp
        for o in order
        for s in size
        for e in endian
        for n in nails
        for op in ops
    ]


def mpz_sizeinbase_data(api):
    """Return (op, base) pairs for every base from 2 to 36."""
    bases = list(map(str, range(2, 37)))
    ops = gen_mpz_args(1000) + gen_mpzs(count=1000, mindigits=100, maxdigits=2000)
    return [(op, b) for op in ops for b in bases]


def get_str_data(ty):
    """Return ("NULL", base, op) tuples for bases 2..36 and -2..-36.

    *ty* selects mpz or mpq operands. The "NULL" first argument is
    presumably the output-buffer parameter; confirm against the consumer of
    this data.

    Raises:
        RuntimeError: if *ty* is neither gmpapi.mpz_t nor gmpapi.mpq_t.
    """
    bases = list(range(2, 37)) + list(range(-2, -37, -1))
    bases = list(map(str, bases))
    if ty == gmpapi.mpz_t:
        ops = gen_mpz_args(1000)
    elif ty == gmpapi.mpq_t:
        ops = gen_mpq_args(20)
    else:
        raise RuntimeError("Unsupported get_str type: " + str(ty))
    return [("NULL", b, op) for b in bases for op in ops]


def mpz_get_str_data(api):
    """Return mpz_get_str test arguments (see get_str_data)."""
    return get_str_data(gmpapi.mpz_t)
def mpq_get_str_data(api):
    """Return mpq_get_str test arguments (see get_str_data)."""
    return get_str_data(gmpapi.mpq_t)


def mpq_set_str_data(api):
    """Return ("0", q, "10") tuples for every nonzero rational/integer string q."""
    args = gen_mpq_args(20) + gen_mpz_args()
    # zero does not match results exactly because the
    # results are not canonicalized first. We choose to
    # exclude zero from test results. The other option is
    # to canonicalize the results after parsing the strings.
    # Instead we exclude zero so that we can independently
    # test correctness of set_str and canonicalization
    # partition() returns the whole string when there is no "/", so this
    # checks the numerator of "n/d" and the value of a plain integer alike.
    nonzero = [arg for arg in args if int(arg.partition("/")[0]) != 0]

    return [("0", q, "10") for q in nonzero]


def get_div_data(n, d, rate=0.2):
    """Generate some inputs that are perfectly divisible.

    With probability *rate*, *n* is replaced by n * d. Returns (n, d).
    """
    if random.random() < rate:
        n = mul_mpzs(n, d)
    return (n, d)


def allow(name, args):
    """Return False if any blacklisted position of *args* holds a banned value."""
    filters = blacklists.get(name, ())
    return not any(args[pos] in disallow for (pos, disallow) in filters)


def fixup_args(name, args):
    """Return the list of argument tuples to emit for one generated *args*.

    If *name* has a fixup registered, its result is returned; otherwise
    ``[args]``.
    """
    fixup = fixups.get(name)
    if fixup is None:
        return [args]
    return fixup(args)


# list of values to be excluded for various api calls
# list format is (pos, [list of values to exclude])
blacklists = {
    "mpz_cdiv_q": [(2, ["0", "-0"])],
    "mpz_fdiv_q": [(2, ["0", "-0"])],
    "mpz_fdiv_r": [(2, ["0", "-0"])],
    "mpz_tdiv_q": [(2, ["0", "-0"])],
    "mpz_fdiv_q_ui": [(2, ["0", "-0"])],
    "mpz_divexact": [(2, ["0", "-0"])],
    "mpz_divisible_p": [(1, ["0", "-0"])],
    "mpz_divexact_ui": [(2, ["0", "-0"])],
    "mpq_set_ui": [(2, ["0", "-0"])],
}

# Per-api transforms that expand one generated argument tuple into one or
# more tuples.
fixups = {
    "mpz_divexact": mpz_divexact_data,
    "mpz_divisible_p": mpz_divisible_p_data,
    "mpz_cdiv_q": mpz_div3_data,
    "mpz_fdiv_q": mpz_div3_data,
    "mpz_fdiv_r": mpz_div3_data,
    "mpz_tdiv_q": mpz_div3_data,
    "mpz_fdiv_q_ui": mpz_div3_data,
    "mpz_divexact_ui": mpz_divexact_data,
    "mpz_pow_ui": mpz_pow_data,
    "mpz_gcd": mpz_gcd_data,
    "mpz_lcm": mpz_gcd_data,
    "mpz_mul_2exp": mpz_mul_2exp_data,
}

# Apis whose complete argument list comes from a dedicated generator
# instead of the per-parameter cross product.
custom = {
    "mpz_export": mpz_export_data,
    "mpz_import": mpz_export_data,
    "mpz_sizeinbase": mpz_sizeinbase_data,
    "mpz_get_str": mpz_get_str_data,
    "mpq_set_str": mpq_set_str_data,
    "mpq_get_str": mpq_get_str_data,
}

if __name__ == "__main__":
    # To restrict output to a single api while debugging, use e.g.:
    #     apis = [gmpapi.get_api("mpq_set_str")]
    apis = gmpapi.apis
    for api in apis:
        for generated in gen_args(api):
            # Distinct loop names: the fixup expansion must not clobber the
            # generated tuple it is expanding.
            for expanded in fixup_args(api.name, generated):
                if allow(api.name, expanded):
                    print("{}|{}".format(api.name, ",".join(expanded)))