#!/usr/bin/env python
"""Generate C test functions for the GMP API and imath's GMP-compat layer.

Run with a single argument, ``gmp`` or ``imath``; the generated C source is
written to standard output.  For every API in ``gmpapi.apis`` that does not
have a custom test, one function of the form

    void test_<api name>(char *out, <params>)

is emitted.  It initialises local variables from its parameters (big-number
values arrive as decimal strings), calls the API function, and appends the
textual form of the return value and of every output parameter to ``out``,
each followed by a space.
"""
import sys
import gmpapi
from gmpapi import void
from gmpapi import ilong
from gmpapi import iint
from gmpapi import ulong
from gmpapi import mpz_t
from gmpapi import size_t
from gmpapi import charp
from gmpapi import mpq_t


class APITest(object):
    """Emit the C test function for one API description.

    The API object's ``name``, ``params``, ``ret_ty`` and ``out_params``
    (indices into ``params``) attributes are read.  Subclasses provide the
    library-specific pieces: ``api_call_prefix``, ``mpz_type`` and
    ``mpq_type``.
    """

    # printf conversion for each scalar return type.  Kept as (type, spec)
    # pairs compared with ``==`` rather than a dict, because nothing here
    # guarantees the gmpapi type objects are hashable.
    _PRINTF_SPECS = (
        (ilong, "%ld"),
        (ulong, "%lu"),
        (iint, "%d"),
        (size_t, "%zu"),
        (charp, "%s"),
    )

    def __init__(self, gmpapi):
        # NOTE: the parameter shadows the ``gmpapi`` module inside __init__
        # only; the name is kept for backward compatibility.
        self.api = gmpapi

    def test_prefix(self):
        """Return the prefix of every generated test function name."""
        return "test"

    def test_param_name(self, ty, i):
        """Return the C parameter name for the ``i``-th argument of type ``ty``.

        Raises:
            RuntimeError: if ``ty`` is not a supported parameter type.
        """
        if ty == mpz_t:
            pname = "p_zs"
        elif ty == ilong:
            pname = "p_si"
        elif ty == ulong:
            pname = "p_ui"
        elif ty == iint:
            pname = "p_i"
        elif ty == charp:
            pname = "p_cs"
        elif ty == mpq_t:
            pname = "p_qs"
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        return pname + str(i)

    def test_param_type(self, ty):
        """Return the C type of a test-function parameter.

        mpz/mpq values are passed in as decimal strings (``char *``); every
        other type is passed through unchanged.
        """
        if ty == mpz_t or ty == mpq_t:
            return "char *"
        return str(ty)

    def test_var_name(self, ty, i):
        """Return the C local-variable name for value ``i`` of type ``ty``.

        ``i`` is a parameter index, or the string ``"_ret"`` for the variable
        that receives the API's return value.

        Raises:
            RuntimeError: if ``ty`` is not a supported type.
        """
        if ty == mpz_t:
            vname = "v_z"
        elif ty == ilong:
            vname = "v_si"
        elif ty == ulong:
            vname = "v_ui"
        elif ty == iint:
            vname = "v_i"
        elif ty == size_t:
            vname = "v_st"
        elif ty == charp:
            vname = "v_cs"
        elif ty == mpq_t:
            vname = "v_q"
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        return vname + str(i)

    def test_var_type(self, ty):
        """Return the C type of a local variable (library-specific for mpz/mpq)."""
        if ty == mpz_t:
            return self.mpz_type()
        if ty == mpq_t:
            return self.mpq_type()
        return str(ty)

    def init_var_from_param(self, ty, var, param):
        """Return C statements initialising local ``var`` from parameter ``param``.

        Big-number variables are init'ed and parsed from the base-10 string
        parameter.  Rationals are additionally canonicalized after parsing.
        Scalars are plain assignments.  The final statement is left without
        its terminating ``;`` for the caller to add.
        """
        code = "\t"
        if ty == mpz_t or ty == mpq_t:
            prefix = self.api_call_prefix(ty)
            code += prefix + "init(" + var + ");\n\t"
            code += prefix + "set_str(" + ",".join([var, param, "10"]) + ")"
            if ty == mpq_t:
                code += ";\n\t"
                code += prefix + "canonicalize(" + var + ")"
        else:
            code += var + "=" + param
        return code

    def init_vars_from_params(self):
        """Return C declarations and initialisations for every API argument."""
        code = ""
        for (i, p) in enumerate(self.api.params):
            param = self.test_param_name(p, i)
            var = self.test_var_name(p, i)
            code += "\t" + self.test_var_type(p) + " " + var + ";\n"
            code += self.init_var_from_param(p, var, param) + ";\n\n"
        return code

    def make_api_call(self):
        """Return the C statement that calls the API under test.

        A non-void result is captured in a ``*_ret`` local.  The ``mpz_``/
        ``mpq_`` prefix of the API name is swapped for the library prefix
        chosen by the subclass.
        """
        bare_name = self.api.name.replace("mpz_", "", 1).replace("mpq_", "", 1)
        call_params = [
            self.test_var_name(p, i) for (i, p) in enumerate(self.api.params)
        ]

        ret = "\t"
        ret_ty = self.api.ret_ty
        if ret_ty != void:
            ret += (
                self.test_var_type(ret_ty)
                + " "
                + self.test_var_name(ret_ty, "_ret")
                + " = "
            )

        # Names not starting with "mpz_" are dispatched to the mpq prefix.
        if self.api.name.startswith("mpz_"):
            prefix = self.api_call_prefix(mpz_t)
        else:
            prefix = self.api_call_prefix(mpq_t)
        return ret + prefix + bare_name + "(" + ",".join(call_params) + ");\n"

    def normalize_cmp(self, ty):
        """Return C code clamping the ``*_ret`` variable to -1, 0 or 1.

        Only the sign of a comparison result is reported, so libraries that
        return different magnitudes still produce identical output.
        """
        cmpval = self.test_var_name(ty, "_ret")
        return (
            " if ({var} > 0) {var} = 1; else if ({var} < 0) {var} = -1;\n\t "
        ).format(var=cmpval)

    def extract_result(self, ty, pos):
        """Return C code appending the text of one result to ``out``.

        ``pos`` is the parameter index of an mpz/mpq output parameter, or -1
        for the return value.  The emitted code advances the C variable
        ``offset`` past the text it wrote.

        Raises:
            RuntimeError: if ``ty`` is not a supported type.
        """
        if ty == mpz_t or ty == mpq_t:
            var = self.test_var_name(ty, pos)
            code = self.api_call_prefix(ty) + "get_str(out+offset, 10," + var + ");\n"
            # Measure only the text just written at out+offset; strlen(out)
            # would re-count everything already in the buffer and place the
            # separator (and the next result) past the string terminator.
            code += "\toffset = offset + strlen(out+offset); "
            code += "out[offset] = ' '; out[offset+1] = 0; offset += 1;"
            return code

        assert pos == -1, "expected a return value, not a param value"
        var = self.test_var_name(ty, "_ret")
        for spec_ty, spec in self._PRINTF_SPECS:
            if ty == spec_ty:
                # sprintf returns the number of characters it wrote, so
                # accumulate it instead of overwriting the running offset.
                return 'offset += sprintf(out+offset, " {} ", {});'.format(
                    spec, var
                )
        raise RuntimeError("Unknown param type: " + str(ty))

    def extract_results(self):
        """Return C code writing the return value and output params to ``out``.

        NOTE(review): nothing in the generated code bounds writes into
        ``out``; the caller must supply a large enough buffer.
        """
        ret_ty = self.api.ret_ty
        code = "\tint offset = 0;\n\t"

        if ret_ty == iint and "cmp" in self.api.name:
            code += self.normalize_cmp(ret_ty)

        # mpq_set_ui stores num/den as given without canonicalizing (see the
        # GMP manual), so canonicalize before printing the result.
        if self.api.name == "mpq_set_ui":
            code += (
                self.api_call_prefix(mpq_t)
                + "canonicalize("
                + self.test_var_name(mpq_t, 0)
                + ");\n\t"
            )

        # The return value is printed first, then the output parameters in
        # the order listed by the API description.
        if ret_ty != void:
            code += self.extract_result(ret_ty, -1) + "\n"
        for pos in self.api.out_params:
            code += "\t" + self.extract_result(self.api.params[pos], pos) + "\n"
        return code + "\n"

    def clear_local_vars(self):
        """Return C code releasing every mpz/mpq local created for the call."""
        code = ""
        for (i, p) in enumerate(self.api.params):
            if p == mpz_t or p == mpq_t:
                var = self.test_var_name(p, i)
                code += "\t" + self.api_call_prefix(p) + "clear(" + var + ");\n"
        return code

    def print_test_code(self, outf):
        """Write the complete C test function for this API to ``outf``."""
        api = self.api
        # Joining "char *out" with the parameters avoids emitting a dangling
        # ", " (invalid C) when the API takes no parameters.
        params = ["char *out"] + [
            self.test_param_type(p) + " " + self.test_param_name(p, i)
            for (i, p) in enumerate(api.params)
        ]
        code = "void {}_{}({})".format(
            self.test_prefix(), api.name, ", ".join(params)
        )
        code += "{\n"
        code += self.init_vars_from_params()
        code += self.make_api_call()
        code += self.extract_results()
        code += self.clear_local_vars()
        code += "}\n"

        outf.write(code)
        outf.write("\n")


class GMPTest(APITest):
    """Generate tests that call GMP itself (``mpz_``/``mpq_`` functions)."""

    def api_call_prefix(self, kind):
        """Return the function-name prefix for ``kind`` (mpz_t or mpq_t)."""
        if kind == mpz_t:
            return "mpz_"
        elif kind == mpq_t:
            return "mpq_"
        else:
            raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        """Return the C type used for integer locals."""
        return "mpz_t"

    def mpq_type(self):
        """Return the C type used for rational locals."""
        return "mpq_t"


class ImathTest(APITest):
    """Generate tests that call imath's compat layer (``impz_``/``impq_``)."""

    def api_call_prefix(self, kind):
        """Return the function-name prefix for ``kind`` (mpz_t or mpq_t)."""
        if kind == mpz_t:
            return "impz_"
        elif kind == mpq_t:
            return "impq_"
        else:
            raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        """Return the C type used for integer locals."""
        return "impz_t"

    def mpq_type(self):
        """Return the C type used for rational locals."""
        return "impq_t"


# NOTE(review): the original #include lines carried no header names (invalid
# C).  They are restored to the headers the generated code needs: the bignum
# API, stdio.h for sprintf and string.h for strlen.  Confirm the names
# against the build.
def print_gmp_header(outf):
    """Write the C prologue for the GMP test file to ``outf``."""
    code = ""
    code += "#include <gmp.h>\n"
    code += "#include <stdio.h>\n"
    code += "#include <string.h>\n"
    code += '#include "gmp_custom_test.c"\n'
    outf.write(code)


def print_imath_header(outf):
    """Write the C prologue for the imath test file to ``outf``.

    The typedefs give imath's mpz_t/mpq_t the one-element-array shape used by
    the generated ``impz_t``/``impq_t`` locals.
    """
    code = ""
    code += "#include <gmp_compat.h>\n"
    code += "#include <stdio.h>\n"
    code += "#include <string.h>\n"
    code += "typedef mpz_t impz_t[1];\n"
    code += "typedef mpq_t impq_t[1];\n"
    code += '#include "imath_custom_test.c"\n'
    outf.write(code)


def _print_tests(outf, test_cls):
    """Write one ``test_cls`` test function per non-custom API to ``outf``."""
    for api in gmpapi.apis:
        if not api.custom_test:
            test_cls(api).print_test_code(outf)


def print_gmp_tests(outf):
    """Write the full GMP test source to ``outf``."""
    print_gmp_header(outf)
    _print_tests(outf, GMPTest)


def print_imath_tests(outf):
    """Write the full imath test source to ``outf``."""
    print_imath_header(outf)
    _print_tests(outf, ImathTest)


def main():
    """Dispatch on the first command-line argument ("gmp" or "imath").

    Returns the process exit status: 0 on success, 2 with a usage message on
    stderr when the argument is missing or unknown (previously a missing
    argument raised IndexError and an unknown one silently did nothing).
    """
    printers = {"gmp": print_gmp_tests, "imath": print_imath_tests}
    if len(sys.argv) < 2 or sys.argv[1] not in printers:
        sys.stderr.write("usage: {} gmp|imath\n".format(sys.argv[0]))
        return 2
    printers[sys.argv[1]](sys.stdout)
    return 0


if __name__ == "__main__":
    sys.exit(main())