xref: /llvm-project/polly/lib/External/isl/imath/tests/gmp-compat-test/genctest.py (revision f98ee40f4b5d7474fc67e82824bf6abbaedb7b1c)
1#!/usr/bin/env python
2import sys
3import gmpapi
4from gmpapi import void
5from gmpapi import ilong
6from gmpapi import iint
7from gmpapi import ulong
8from gmpapi import mpz_t
9from gmpapi import size_t
10from gmpapi import charp
11from gmpapi import mpq_t
12
13
class APITest(object):
    """Generate the C test function for a single API description.

    The generated function has the form ``void test_<name>(char *out, ...)``.
    It turns its arguments into library values (mpz/mpq values arrive as
    base-10 strings), calls the API, and appends the return value and each
    out-parameter to ``out`` as space-separated text.

    Subclasses choose the target library by implementing
    ``api_call_prefix``, ``mpz_type`` and ``mpq_type``.
    """

    def __init__(self, gmpapi):
        # API description. This class reads its ``name``, ``params``,
        # ``ret_ty`` and ``out_params`` attributes.
        self.api = gmpapi

    def api_call_prefix(self, kind):
        """Return the C function-name prefix for ``kind`` (mpz_t or mpq_t).

        Must be overridden by subclasses.
        """
        raise NotImplementedError("subclasses must implement api_call_prefix")

    def mpz_type(self):
        """Return the C type used to declare integer variables (override)."""
        raise NotImplementedError("subclasses must implement mpz_type")

    def mpq_type(self):
        """Return the C type used to declare rational variables (override)."""
        raise NotImplementedError("subclasses must implement mpq_type")

    def test_prefix(self):
        """Return the prefix shared by every generated C test-function name."""
        return "test"

    def test_param_name(self, ty, i):
        """Return the C name of the ``i``-th test-function parameter.

        Raises:
            RuntimeError: if ``ty`` is not a supported parameter type.
        """
        if ty == mpz_t:
            pname = "p_zs"
        elif ty == ilong:
            pname = "p_si"
        elif ty == ulong:
            pname = "p_ui"
        elif ty == iint:
            pname = "p_i"
        elif ty == size_t:
            # Keep in step with test_var_name, which already accepts size_t.
            pname = "p_st"
        elif ty == charp:
            pname = "p_cs"
        elif ty == mpq_t:
            pname = "p_qs"
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        return pname + str(i)

    def test_param_type(self, ty):
        """Return the C type of a test-function parameter of type ``ty``.

        mpz/mpq values are passed in as strings and parsed later.
        """
        if ty == mpz_t or ty == mpq_t:
            pty_name = "char *"
        else:
            pty_name = str(ty)
        return pty_name

    def test_var_name(self, ty, i):
        """Return the C name of the local variable for value ``i`` of type ``ty``.

        ``i`` is a parameter index or the suffix ``"_ret"`` for the result.

        Raises:
            RuntimeError: if ``ty`` is not a supported type.
        """
        if ty == mpz_t:
            vname = "v_z"
        elif ty == ilong:
            vname = "v_si"
        elif ty == ulong:
            vname = "v_ui"
        elif ty == iint:
            vname = "v_i"
        elif ty == size_t:
            vname = "v_st"
        elif ty == charp:
            vname = "v_cs"
        elif ty == mpq_t:
            vname = "v_q"
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        return vname + str(i)

    def test_var_type(self, ty):
        """Return the C type used to declare a local variable of type ``ty``."""
        if ty == mpz_t:
            return self.mpz_type()
        elif ty == mpq_t:
            return self.mpq_type()
        else:
            return str(ty)

    def init_var_from_param(self, ty, var, param):
        """Return C code (without a final ';') initialising ``var`` from ``param``.

        mpz/mpq variables are init'ed and parsed from a base-10 string.
        Rationals are also canonicalized. Other types are assigned directly.
        """
        code = "\t"
        if ty == mpz_t or ty == mpq_t:
            code += self.api_call_prefix(ty) + "init(" + var + ");\n\t"
            code += (
                self.api_call_prefix(ty)
                + "set_str("
                + ",".join([var, param, "10"])
                + ")"
            )
            if ty == mpq_t:
                code += ";\n\t"
                code += self.api_call_prefix(ty) + "canonicalize(" + var + ")"
        else:
            code += var + "=" + param
        return code

    def init_vars_from_params(self):
        """Return C declarations plus initialisation for every API parameter."""
        code = ""
        for (i, p) in enumerate(self.api.params):
            param = self.test_param_name(p, i)
            code += "\t"
            code += self.test_var_type(p) + " "
            var = self.test_var_name(p, i)
            code += var + ";\n"
            code += self.init_var_from_param(p, var, param) + ";\n\n"
        return code

    def make_api_call(self):
        """Return the C statement that calls the API under test.

        A non-void result is stored in a ``*_ret`` local declared by the
        same statement.
        """
        name = self.api.name
        # Names starting with "mpz_" use the integer family; all others use
        # the rational family.
        kind = mpz_t if name.startswith("mpz_") else mpq_t
        # Strip only a leading family prefix. str.replace would also drop an
        # "mpz_"/"mpq_" that happens to appear later in the name.
        bare_name = name
        for family in ("mpz_", "mpq_"):
            if bare_name.startswith(family):
                bare_name = bare_name[len(family):]
                break
        call_params = [
            self.test_var_name(p, i) for (i, p) in enumerate(self.api.params)
        ]
        ret = "\t"
        ret_ty = self.api.ret_ty
        if ret_ty != void:
            ret += (
                self.test_var_type(ret_ty)
                + " "
                + self.test_var_name(ret_ty, "_ret")
                + " = "
            )
        prefix = self.api_call_prefix(kind)
        return ret + prefix + bare_name + "(" + ",".join(call_params) + ");\n"

    def normalize_cmp(self, ty):
        """Return C code that clamps the ``*_ret`` result to -1, 0 or 1.

        Comparison results are compared by sign. Clamping lets backends that
        return different magnitudes still produce identical output.
        """
        cmpval = self.test_var_name(ty, "_ret")
        code = (
            "\n"
            "\tif ({var} > 0)\n"
            "\t  {var} = 1;\n"
            "\telse if ({var} < 0)\n"
            "\t  {var} = -1;\n\t\n"
        ).format(var=cmpval)
        return code

    def extract_result(self, ty, pos):
        """Return C code that appends one result value to ``out``.

        Args:
            ty: type of the value.
            pos: parameter index of an out-parameter, or -1 for the API's
                return value.

        Each value is written at ``out+offset``, followed by a space.
        ``offset`` is advanced past it.

        Raises:
            RuntimeError: if ``ty`` is unsupported, or a non-mpz/mpq type is
                given with ``pos != -1``.
        """
        # The return value lives in the "*_ret" local declared by
        # make_api_call. Out-parameters use their positional names.
        suffix = "_ret" if pos == -1 else pos
        if ty == mpz_t or ty == mpq_t:
            var = self.test_var_name(ty, suffix)
            code = self.api_call_prefix(ty) + "get_str(out+offset, 10," + var + ");\n"
            # Measure only the text just written. strlen(out) would also count
            # everything before ``offset`` and over-advance it.
            code += "\toffset = offset + strlen(out+offset); "
            code += "out[offset] = ' '; out[offset+1] = 0; offset += 1;"
            return code

        if pos != -1:
            raise RuntimeError("expected a return value, not a param value")

        formats = (
            (ilong, "%ld"),
            (ulong, "%lu"),
            (iint, "%d"),
            (size_t, "%zu"),
            (charp, "%s"),
        )
        for known, fmt in formats:
            if ty == known:
                break
        else:
            raise RuntimeError("Unknown param type: " + str(ty))
        var = self.test_var_name(ty, "_ret")
        # sprintf returns the number of characters written. Accumulate it so
        # the code stays correct even if something precedes this value.
        return 'offset += sprintf(out+offset, " ' + fmt + ' ", ' + var + ");"

    def extract_results(self):
        """Return C code that writes the return value and out-params to ``out``."""
        ret_ty = self.api.ret_ty
        code = "\tint offset = 0;\n\t"

        # Normalize comparison results to -1/0/1.
        if ret_ty == iint and "cmp" in self.api.name:
            code += self.normalize_cmp(ret_ty)

        # Canonicalize mpq_set_ui's result before printing. Presumably this is
        # because mpq_set_ui does not reduce the fraction (see the GMP manual).
        if self.api.name == "mpq_set_ui":
            code += (
                self.api_call_prefix(mpq_t)
                + "canonicalize("
                + self.test_var_name(mpq_t, 0)
                + ");\n\t"
            )

        # Return value first, then out-params in declaration order.
        if ret_ty != void:
            code += self.extract_result(ret_ty, -1) + "\n"

        for pos in self.api.out_params:
            code += "\t"
            code += self.extract_result(self.api.params[pos], pos) + "\n"

        return code + "\n"

    def clear_local_vars(self):
        """Return C code that clears every mpz/mpq parameter variable."""
        code = ""
        for (i, p) in enumerate(self.api.params):
            if p == mpz_t or p == mpq_t:
                var = self.test_var_name(p, i)
                code += "\t" + self.api_call_prefix(p) + "clear(" + var + ");\n"
        return code

    def print_test_code(self, outf):
        """Write the complete C test function for ``self.api`` to ``outf``."""
        api = self.api
        # Build "out" into the list so an API with no parameters does not
        # produce a dangling ", " in the signature.
        params = ["char *out"] + [
            self.test_param_type(p) + " " + self.test_param_name(p, i)
            for (i, p) in enumerate(api.params)
        ]
        code = "void {}_{}({})".format(
            self.test_prefix(), api.name, ", ".join(params)
        )
        code += "{\n"
        code += self.init_vars_from_params()
        code += self.make_api_call()
        code += self.extract_results()
        code += self.clear_local_vars()
        code += "}\n"
        outf.write(code)
        outf.write("\n")
215
216
class GMPTest(APITest):
    """Test generator that targets GMP itself (``mpz_``/``mpq_`` names)."""

    # The constructor is inherited unchanged from APITest.

    def api_call_prefix(self, kind):
        """Map ``kind`` (mpz_t or mpq_t) to GMP's function-name prefix.

        Raises:
            RuntimeError: for any other kind.
        """
        for known_kind, prefix in ((mpz_t, "mpz_"), (mpq_t, "mpq_")):
            if kind == known_kind:
                return prefix
        raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        """Return GMP's C integer type name."""
        return "mpz_t"

    def mpq_type(self):
        """Return GMP's C rational type name."""
        return "mpq_t"
234
235
class ImathTest(APITest):
    """Test generator that targets imath (``impz_``/``impq_`` names and types)."""

    # The constructor is inherited unchanged from APITest.

    def api_call_prefix(self, kind):
        """Map ``kind`` (mpz_t or mpq_t) to the imath-side function-name prefix.

        Raises:
            RuntimeError: for any other kind.
        """
        for known_kind, prefix in ((mpz_t, "impz_"), (mpq_t, "impq_")):
            if kind == known_kind:
                return prefix
        raise RuntimeError("Unknown call kind: " + str(kind))

    def mpz_type(self):
        """Return the C integer type name used in imath tests."""
        return "impz_t"

    def mpq_type(self):
        """Return the C rational type name used in imath tests."""
        return "impq_t"
253
254
def print_gmp_header(outf):
    """Write the C prologue (#include lines) of the GMP test file to ``outf``."""
    header_lines = [
        "#include <gmp.h>",
        "#include <stdio.h>",
        "#include <string.h>",
        '#include "gmp_custom_test.c"',
    ]
    outf.write("".join(line + "\n" for line in header_lines))
262
263
def print_imath_header(outf):
    """Write the C prologue of the imath test file to ``outf``.

    Besides the includes, this typedefs the ``impz_t``/``impq_t`` names
    (which ImathTest uses for variable declarations) in terms of
    ``mpz_t``/``mpq_t``.
    """
    header_lines = [
        "#include <gmp_compat.h>",
        "#include <stdio.h>",
        "#include <string.h>",
        "typedef mpz_t impz_t[1];",
        "typedef mpq_t impq_t[1];",
        '#include "imath_custom_test.c"',
    ]
    outf.write("".join(line + "\n" for line in header_lines))
273
274
def print_gmp_tests(outf):
    """Write the GMP header, then one generated test per non-custom API."""
    print_gmp_header(outf)
    generated = (entry for entry in gmpapi.apis if not entry.custom_test)
    for entry in generated:
        GMPTest(entry).print_test_code(outf)
280
281
def print_imath_tests(outf):
    """Write the imath header, then one generated test per non-custom API."""
    print_imath_header(outf)
    generated = (entry for entry in gmpapi.apis if not entry.custom_test)
    for entry in generated:
        ImathTest(entry).print_test_code(outf)
287
288
def main():
    """Write generated C test code for the backend named in ``sys.argv[1]``.

    ``gmp`` writes the GMP test file and ``imath`` writes the imath test
    file, both to stdout. A missing or unrecognised argument exits with a
    usage message on stderr and a non-zero status.
    """
    prog = sys.argv[0] if sys.argv else "genctest.py"
    if len(sys.argv) < 2 or sys.argv[1] not in ("gmp", "imath"):
        # Fail loudly. Otherwise a missing argument raises IndexError and an
        # unknown backend silently produces no output at all.
        sys.exit("usage: {} gmp|imath".format(prog))

    if sys.argv[1] == "gmp":
        print_gmp_tests(sys.stdout)
    else:
        print_imath_tests(sys.stdout)
296
297
# Generate code only when run as a script, not when imported.
if __name__ == "__main__":
    main()
300