#!/usr/bin/env python
"""Generate C wrapper functions that exercise the GMP-compatible API.

For every API in ``gmpapi.apis`` that is not marked ``custom_test``, this
script emits a C function ``test_<api name>(char *out, <params>)`` which

1. declares one local per parameter and initializes it from the matching
   parameter (``mpz``/``mpq`` values are parsed from base-10 strings),
2. calls the function under test,
3. appends the return value and every output parameter to ``out``, and
4. clears the ``mpz``/``mpq`` locals.

Pass ``gmp`` as the first argument to generate tests against GMP itself, or
``imath`` to generate tests against imath's GMP compatibility layer. The C
code is written to stdout.
"""
import sys
import gmpapi
from gmpapi import void
from gmpapi import ilong
from gmpapi import iint
from gmpapi import ulong
from gmpapi import mpz_t
from gmpapi import size_t
from gmpapi import charp
from gmpapi import mpq_t


# (type, prefix) pairs for the C parameter names of a test wrapper; the
# parameter's position is appended, e.g. "p_zs0".
_PARAM_NAME_PREFIXES = (
    (mpz_t, "p_zs"),
    (ilong, "p_si"),
    (ulong, "p_ui"),
    (iint, "p_i"),
    (size_t, "p_st"),
    (charp, "p_cs"),
    (mpq_t, "p_qs"),
)

# (type, prefix) pairs for the C local variable names inside a test wrapper;
# a position or "_ret" suffix is appended, e.g. "v_z0" or "v_si_ret".
_VAR_NAME_PREFIXES = (
    (mpz_t, "v_z"),
    (ilong, "v_si"),
    (ulong, "v_ui"),
    (iint, "v_i"),
    (size_t, "v_st"),
    (charp, "v_cs"),
    (mpq_t, "v_q"),
)

# (type, printf format) pairs used to print scalar return values.
_SCALAR_FORMATS = (
    (ilong, " %ld "),
    (ulong, " %lu "),
    (iint, " %d "),
    (size_t, " %zu "),
    (charp, " %s "),
)


def _lookup(ty, table):
    """Return the value paired with ``ty`` in ``table``.

    Entries are compared with ``==`` in table order, so the first match wins.

    Raises:
        RuntimeError: if ``ty`` does not appear in ``table``.
    """
    for known_ty, value in table:
        if ty == known_ty:
            return value
    raise RuntimeError("Unknown param type: " + str(ty))


class APITest(object):
    """Generate the C test wrapper for a single API description.

    Subclasses provide the library-specific hooks ``api_call_prefix``,
    ``mpz_type`` and ``mpq_type``.
    """

    def __init__(self, gmpapi):
        # API description; this class reads .name, .params, .ret_ty and
        # .out_params from it.
        self.api = gmpapi

    def api_call_prefix(self, kind):
        """Return the C function-name prefix for ``kind`` (mpz_t or mpq_t)."""
        raise NotImplementedError

    def mpz_type(self):
        """Return the C type name used for integer locals."""
        raise NotImplementedError

    def mpq_type(self):
        """Return the C type name used for rational locals."""
        raise NotImplementedError

    def test_prefix(self):
        """Return the prefix of every generated wrapper's name."""
        return "test"

    def test_param_name(self, ty, i):
        """Return the C name of wrapper parameter ``i`` of type ``ty``.

        Raises:
            RuntimeError: for an unsupported type.
        """
        return _lookup(ty, _PARAM_NAME_PREFIXES) + str(i)

    def test_param_type(self, ty):
        """Return the C type of a wrapper parameter.

        mpz/mpq arguments are passed as strings and parsed inside the wrapper.
        """
        if ty == mpz_t or ty == mpq_t:
            return "char *"
        return str(ty)

    def test_var_name(self, ty, i):
        """Return the C name of the local holding value ``i`` of type ``ty``.

        ``i`` is a parameter position or a suffix such as ``"_ret"``.

        Raises:
            RuntimeError: for an unsupported type.
        """
        return _lookup(ty, _VAR_NAME_PREFIXES) + str(i)

    def test_var_type(self, ty):
        """Return the C type of a local of type ``ty``."""
        if ty == mpz_t:
            return self.mpz_type()
        if ty == mpq_t:
            return self.mpq_type()
        return str(ty)

    def init_var_from_param(self, ty, var, param):
        """Return C code (without trailing ';') initializing ``var`` from ``param``.

        mpz/mpq locals are init'ed and parsed from the base-10 string
        ``param``; mpq values are additionally canonicalized. Other types are
        assigned directly.
        """
        code = "\t"
        if ty == mpz_t or ty == mpq_t:
            prefix = self.api_call_prefix(ty)
            code += prefix + "init(" + var + ");\n\t"
            code += prefix + "set_str(" + ",".join([var, param, "10"]) + ")"
            if ty == mpq_t:
                code += ";\n\t"
                code += prefix + "canonicalize(" + var + ")"
        else:
            code += var + "=" + param
        return code

    def init_vars_from_params(self):
        """Return C code declaring and initializing one local per parameter."""
        code = ""
        for i, p in enumerate(self.api.params):
            param = self.test_param_name(p, i)
            var_type = self.test_var_type(p)
            var = self.test_var_name(p, i)
            code += "\t" + var_type + " " + var + ";\n"
            code += self.init_var_from_param(p, var, param) + ";\n\n"
        return code

    def make_api_call(self):
        """Return the C statement that calls the API under test.

        A non-void result is stored in a ``<var>_ret`` local.
        """
        # Drop the GMP prefix; the library-specific prefix is re-added below.
        bare_name = self.api.name.replace("mpz_", "", 1).replace("mpq_", "", 1)
        call_params = [
            self.test_var_name(p, i) for (i, p) in enumerate(self.api.params)
        ]
        ret = "\t"
        ret_ty = self.api.ret_ty
        if ret_ty != void:
            ret += (
                self.test_var_type(ret_ty)
                + " "
                + self.test_var_name(ret_ty, "_ret")
                + " = "
            )
        # Call the mpz flavour for mpz_* APIs, the mpq flavour otherwise.
        if self.api.name.startswith("mpz_"):
            prefix = self.api_call_prefix(mpz_t)
        else:
            prefix = self.api_call_prefix(mpq_t)
        return ret + prefix + bare_name + "(" + ",".join(call_params) + ");\n"

    def normalize_cmp(self, ty):
        """Return C code clamping a comparison result to -1, 0 or 1.

        Only the sign of a cmp result is meaningful, so it is normalized
        before being printed.
        """
        cmpval = self.test_var_name(ty, "_ret")
        return (
            "\n"
            "\tif ({var} > 0)\n"
            "\t  {var} = 1;\n"
            "\telse if ({var} < 0)\n"
            "\t  {var} = -1;\n\t\n"
        ).format(var=cmpval)

    def extract_result(self, ty, pos):
        """Return C code appending one result value to ``out``.

        Args:
            ty: type of the value.
            pos: parameter position of an output parameter, or -1 for the
                return value. Scalars are only supported as return values.

        After the generated code runs, ``offset`` indexes the NUL that
        terminates the text written to ``out`` so far.

        Raises:
            RuntimeError: for an unsupported scalar type.
        """
        if ty == mpz_t or ty == mpq_t:
            var = self.test_var_name(ty, pos)
            code = self.api_call_prefix(ty) + "get_str(out+offset, 10," + var + ");\n"
            # get_str NUL-terminates the text it wrote at out+offset, so only
            # that text's length is added (strlen(out) would recount the
            # already-written prefix and leave the separator past the NUL).
            code += "\toffset += strlen(out+offset); "
            code += "out[offset] = ' '; out[offset+1] = 0; offset += 1;"
            return code

        assert pos == -1, "expected a return value, not a param value"
        fmt = _lookup(ty, _SCALAR_FORMATS)
        var = self.test_var_name(ty, "_ret")
        # sprintf returns the number of characters written; accumulate it so
        # the code stays correct wherever in `out` it is emitted.
        return 'offset += sprintf(out+offset, "' + fmt + '", ' + var + ");"

    def extract_results(self):
        """Return C code writing the return value and output params to ``out``."""
        ret_ty = self.api.ret_ty
        code = "\tint offset = 0;\n\t"

        # Normalize cmp return values to -1/0/1.
        if ret_ty == iint and "cmp" in self.api.name:
            code += self.normalize_cmp(ret_ty)

        # mpq_set_ui does not canonicalize its result, so do it before printing.
        if self.api.name == "mpq_set_ui":
            code += (
                self.api_call_prefix(mpq_t)
                + "canonicalize("
                + self.test_var_name(mpq_t, 0)
                + ");\n\t"
            )

        # Return value first, then the output parameters in listed order.
        if ret_ty != void:
            code += self.extract_result(ret_ty, -1) + "\n"

        for pos in self.api.out_params:
            code += "\t" + self.extract_result(self.api.params[pos], pos) + "\n"

        return code + "\n"

    def clear_local_vars(self):
        """Return C code clearing every mpz/mpq local."""
        code = ""
        for i, p in enumerate(self.api.params):
            if p == mpz_t or p == mpq_t:
                var = self.test_var_name(p, i)
                code += "\t" + self.api_call_prefix(p) + "clear(" + var + ");\n"
        return code

    def print_test_code(self, outf):
        """Write the complete C test wrapper for this API to ``outf``."""
        api = self.api
        params = [
            self.test_param_type(p) + " " + self.test_param_name(p, i)
            for (i, p) in enumerate(api.params)
        ]
        code = "void {}_{}(char *out, {})".format(
            self.test_prefix(), api.name, ", ".join(params)
        )
        code += "{\n"
        code += self.init_vars_from_params()
        code += self.make_api_call()
        code += self.extract_results()
        code += self.clear_local_vars()
        code += "}\n"
        outf.write(code)
        outf.write("\n")


class GMPTest(APITest):
    """Test generator that calls GMP's own mpz_/mpq_ functions."""

    def __init__(self, gmpapi):
        super(GMPTest, self).__init__(gmpapi)

    def api_call_prefix(self, kind):
        """Return the GMP prefix for ``kind``.

        Raises:
            RuntimeError: if ``kind`` is neither mpz_t nor mpq_t.
        """
        if kind == mpz_t:
            return "mpz_"
        elif kind == mpq_t:
            return "mpq_"
        else:
            raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        return "mpz_t"

    def mpq_type(self):
        return "mpq_t"


class ImathTest(APITest):
    """Test generator that calls imath's impz_/impq_ compatibility functions."""

    def __init__(self, gmpapi):
        super(ImathTest, self).__init__(gmpapi)

    def api_call_prefix(self, kind):
        """Return the imath compatibility prefix for ``kind``.

        Raises:
            RuntimeError: if ``kind`` is neither mpz_t nor mpq_t.
        """
        if kind == mpz_t:
            return "impz_"
        elif kind == mpq_t:
            return "impq_"
        else:
            raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        return "impz_t"

    def mpq_type(self):
        return "impq_t"


def print_gmp_header(outf):
    """Write the includes needed by the GMP flavour of the tests."""
    code = ""
    code += "#include <gmp.h>\n"
    code += "#include <stdio.h>\n"
    code += "#include <string.h>\n"
    code += '#include "gmp_custom_test.c"\n'
    outf.write(code)


def print_imath_header(outf):
    """Write the includes and typedefs needed by the imath flavour."""
    code = ""
    code += "#include <gmp_compat.h>\n"
    code += "#include <stdio.h>\n"
    code += "#include <string.h>\n"
    code += "typedef mpz_t impz_t[1];\n"
    code += "typedef mpq_t impq_t[1];\n"
    code += '#include "imath_custom_test.c"\n'
    outf.write(code)


def print_gmp_tests(outf):
    """Write the GMP header and every generated GMP test wrapper to ``outf``."""
    print_gmp_header(outf)
    for api in gmpapi.apis:
        if not api.custom_test:
            GMPTest(api).print_test_code(outf)


def print_imath_tests(outf):
    """Write the imath header and every generated imath test wrapper to ``outf``."""
    print_imath_header(outf)
    for api in gmpapi.apis:
        if not api.custom_test:
            ImathTest(api).print_test_code(outf)


def main():
    """Write the test flavour named by ``sys.argv[1]`` ("gmp" or "imath") to stdout.

    Exits with status 1 and a usage message when the argument is missing or
    not one of the known flavours.
    """
    generators = {"gmp": print_gmp_tests, "imath": print_imath_tests}
    if len(sys.argv) < 2 or sys.argv[1] not in generators:
        sys.stderr.write("usage: %s {gmp|imath}\n" % sys.argv[0])
        sys.exit(1)
    generators[sys.argv[1]](sys.stdout)


if __name__ == "__main__":
    main()