xref: /llvm-project/polly/lib/External/isl/isl_test_python.py (revision f98ee40f4b5d7474fc67e82824bf6abbaedb7b1c)
1# Copyright 2016-2017 Tobias Grosser
2#
3# Use of this software is governed by the MIT license
4#
5# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
6
7import sys
8import isl
9
# Check that isl objects can be built in several different ways.
#
# Covered here:
#  - building a value from a string
#  - building a value from an integer
#  - a static constructor that takes no argument (isl.val.zero)
#  - converting a basic_set into a set
#  - building an empty union set
#
#  Building from integers and from strings is also exercised by the
#  parameter-type tests; what matters here is that the right overload is
#  picked among several overloaded constructors.
#
def test_constructors():
    from_string = isl.val("0")
    assert from_string.is_zero()

    from_int = isl.val(0)
    assert from_int.is_zero()

    from_static = isl.val.zero()
    assert from_static.is_zero()

    point_basic_set = isl.basic_set("{ [1] }")
    expected_set = isl.set("{ [1] }")
    converted = isl.set(point_basic_set)
    assert converted.is_equal(expected_set)

    # Uniting with the empty union set must leave the union set unchanged.
    populated = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert populated.is_equal(populated.union(nothing))
41
42
# Check that the integer "i" yields the same isl value whether it is passed
# directly as an int or as its decimal string representation.
#
def test_int(i):
    assert isl.val(i).eq(isl.val(str(i)))
49
50
# Check integer function parameters.
#
# Exercise both ends of the native integer range (sys.maxsize and
# -sys.maxsize - 1) as well as zero.
#
def test_parameters_int():
    largest = sys.maxsize
    smallest = -largest - 1
    for value in (largest, smallest, 0):
        test_int(value)
59
60
# Check isl object parameters.
#
# An isl object may be passed both from a named variable (lvalue) and as the
# direct result of another call (rvalue).  Parameters are also expected to be
# converted automatically between types related by inheritance (here a
# basic_set is passed to is_equal on a set).  Finally, a method that takes no
# argument other than the object it is called on (isl.val.inv) is exercised.
#
def test_parameters_obj():
    first = isl.set("{ [0] }")
    second = isl.set("{ [1] }")
    third = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # Intermediate result kept in a variable before being used again.
    partial = first.union(second)
    assert partial.union(third).is_equal(expected)

    # Same computation with each intermediate passed on directly.
    assert first.union(second).union(third).is_equal(expected)

    # basic_set argument for a method called on a set.
    assert first.is_equal(isl.basic_set("{ [0] }"))

    two = isl.val(2)
    half = isl.val("1/2")
    assert two.inv().eq(half)
89
90
# Check that functions accept each kind of parameter: plain integers and
# isl objects.
#
def test_parameters():
    for check in (test_parameters_int, test_parameters_obj):
        check()
98
99
# Check that isl objects are returned correctly.
#
# Only checks that the object produced by combining two others reaches the
# caller intact: 1 + 2 must compare equal to 3.
#
def test_return_obj():
    one, two, three = isl.val("1"), isl.val("2"), isl.val("3")

    total = one.add(two)
    assert total.eq(three)
113
114
# Check that integer results are returned correctly, using the sign of a
# positive, a negative and a zero value.
#
def test_return_int():
    positive = isl.val("1")
    negative = isl.val("-1")
    zero = isl.val("0")

    assert positive.sgn() > 0
    assert negative.sgn() < 0
    assert zero.sgn() == 0
125
126
# Check that isl_bool results are returned correctly.
#
# Both directions of the conversion to a Python truth value are covered: a
# result that must be true and one that must be false.
#
def test_return_bool():
    empty_set = isl.set("{ : false }")
    universe = isl.set("{ : }")

    expect_true = empty_set.is_empty()
    expect_false = universe.is_empty()

    assert expect_true
    assert not expect_false
140
141
# Check that strings are returned correctly.
#
# The strings come from the overloaded isl.ast_build.expr_from method.  It is
# called once with a piecewise affine expression and once with a set, and
# each resulting AST expression is printed as C.
#
def test_return_string():
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Named so as not to shadow the builtin "set".
    non_negative = isl.set("[n] -> { : n >= 0 }")

    expr = build.expr_from(pw_aff)
    assert expr.to_C_str() == "n"

    expr = build.expr_from(non_negative)
    assert expr.to_C_str() == "n >= 0"
158
159
# Check that return values are handled correctly.
#
# Covers each kind of result: isl objects, integers, booleans and strings.
#
def test_return():
    for check in (
        test_return_obj,
        test_return_int,
        test_return_bool,
        test_return_string,
    ):
        check()
170
171
# Helper class for the isl.id.user test.
#
# An instance is attached to an isl.id as its user object.  Its "value"
# attribute, always 42, is checked after the object is retrieved from the
# identifier.
#
class S:
    def __init__(self):
        self.value: int = 42
177
178
# Check isl.id.user.
#
# The object attached to an identifier when it is created must be
# retrievable again.  An identifier created without one must report None.
#
def test_user():
    numbered = isl.id("test", 5)
    plain = isl.id("test2")
    holder = isl.id("S", S())

    assert numbered.user() == 5, f"unexpected user object {numbered.user()}"
    assert plain.user() is None, f"unexpected user object {plain.user()}"

    retrieved = holder.user()
    assert isinstance(retrieved, S), f"unexpected user object {retrieved}"
    assert retrieved.value == 42, f"unexpected user object {retrieved}"
193
194
# Check that foreach functions are modeled correctly.
#
# A closure passed as the callback of a "foreach" function must be called,
# and variables it captures must work.  An exception raised by the closure
# must be propagated out of the foreach call.
#
def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # Named so as not to shadow the builtin "list".
    collected = []

    def add(bs):
        collected.append(bs)

    s.foreach_basic_set(add)

    # The set consists of three separate points, so three pairwise
    # different basic sets, each contained in s, are expected.
    assert len(collected) == 3
    for bs in collected:
        assert bs.is_subset(s)
    assert not collected[0].is_equal(collected[1])
    assert not collected[0].is_equal(collected[2])
    assert not collected[1].is_equal(collected[2])

    def fail(bs):
        raise Exception("fail")

    # Catch Exception rather than using a bare except, so that
    # KeyboardInterrupt/SystemExit are never mistaken for the propagated error.
    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        caught = True
    assert caught
229
230
# Check the "foreach_scc" functions.
#
# The three elements can be ordered completely, even though two of them
# ("a" and "b") are not comparable with each other.  The expected order of
# the resulting components is b, c, a.
#
def test_foreach_scc():
    names = isl.id_list(3)
    # Single-element list so that the nested callback can rebind the result.
    result = [isl.id_list(3)]
    data = {
        "a": isl.map("{ [0] -> [1] }"),
        "b": isl.map("{ [1] -> [0] }"),
        "c": isl.map("{ [i = 0:1] -> [i] }"),
    }
    for key in data:
        names = names.add(key)
    identity = data["a"].space().domain().identity_multi_pw_aff_on_domain()

    def follows(a, b):
        combined = data[b.name()].apply_domain(data[a.name()])
        return not combined.lex_ge_at(identity).is_empty()

    def add_single(scc):
        # Every component is expected to hold exactly one element.
        assert scc.size() == 1
        result[0] = result[0].concat(scc)

    names.foreach_scc(follows, add_single)

    ordered = result[0]
    assert ordered.size() == 3
    for position, name in enumerate(("b", "c", "a")):
        assert ordered.at(position).name() == name
261
262
# Check the "every" functions.
#
# Besides the generic behavior, check that an exception raised by the
# callback is propagated out of the "every" call.
#
def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()

    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()

    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(not_in_A)

    def fail(s):
        raise Exception("fail")

    # The call must go to every_set itself.  A misspelled method name raises
    # an AttributeError, which this handler would also accept, so the check
    # would pass without ever exercising exception propagation.
    caught = False
    try:
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught
300
301
# Check basic construction of spaces.
#
# Named tuples are added to the unit space: "A" with 3 dimensions gives a set
# space, and "B" with 2 dimensions added to that gives a map space.  The
# universes of both spaces must match the corresponding parsed set and map.
#
def test_space():
    set_space = isl.space.unit().add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)
    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))
313
314
# Build a small schedule tree.
#
# It has a sequence node near the top with one filter branch per statement
# (A with 10 instances, B with 20).  Each branch gets a one-dimensional band
# node, and the band member for A is marked coincident.
#
def construct_schedule_tree():
    domain_a = isl.union_set("{ A[i] : 0 <= i < 10 }")
    domain_b = isl.union_set("{ B[i] : 0 <= i < 20 }")

    # Insert the sequence below the domain node, one filter per statement.
    root = isl.schedule_node.from_domain(domain_a.union(domain_b))
    node = root.child(0).insert_sequence(
        isl.union_set_list(domain_a).add(domain_b)
    )

    # Branch 0: insert A's band, mark its member coincident, climb back up.
    schedule_a = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    node = node.child(0).child(0)
    node = node.insert_partial_schedule(schedule_a).member_set_coincident(0, True)
    node = node.ancestor(2)

    # Branch 1: insert B's band and climb back up.
    schedule_b = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    node = node.child(1).child(0).insert_partial_schedule(schedule_b)
    node = node.ancestor(2)

    return node.schedule()
343
344
# Check basic schedule tree functionality.
#
# On the tree built by construct_schedule_tree:
# - the root node must be a domain node
# - map_descendant_bottom_up must invoke its callback 8 times and must
#   propagate an exception raised by the callback
# - foreach_descendant_top_down must visit 8 nodes when the callback returns
#   True, but only the root when it returns False
# - every_descendant must evaluate its predicate, propagate exceptions, and
#   can be used to collect the union of all filters, which must equal the
#   domain
#
def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    # Exact type comparison on purpose: the root must come back as an
    # isl.schedule_node_domain object.
    assert type(root) == isl.schedule_node_domain

    visits = [0]

    def count_and_keep(node):
        visits[0] += 1
        return node

    root = root.map_descendant_bottom_up(count_and_keep)
    assert visits[0] == 8

    def fail_map(node):
        raise Exception("fail")

    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        caught = True
    assert caught

    visits[0] = 0

    def count_and_descend(node):
        visits[0] += 1
        return True

    root.foreach_descendant_top_down(count_and_descend)
    assert visits[0] == 8

    visits[0] = 0

    def count_and_stop(node):
        visits[0] += 1
        return False

    # Returning False keeps the traversal from descending below the root.
    root.foreach_descendant_top_down(count_and_stop)
    assert visits[0] == 1

    def is_not_domain(node):
        return type(node) != isl.schedule_node_domain

    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    def fail(node):
        raise Exception("fail")

    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught

    domain = root.domain()
    filters = [isl.union_set("{}")]

    def collect_filters(node):
        if type(node) == isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True

    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])
423
424
# Check marking band members for unrolling.
#
# "schedule" is the schedule built by construct_schedule_tree, which covers
# two statements with 10 and 20 instances respectively.  Once every band
# member is marked for unrolling, the AST generator makes 30 at-each-domain
# calls (10 + 20).
#
def test_ast_build_unroll(schedule):
    def mark_unroll(node):
        if type(node) != isl.schedule_node_band:
            return node
        return node.member_set_ast_loop_unroll(0)

    unrolled = schedule.root().map_descendant_bottom_up(mark_unroll).schedule()

    calls = [0]

    def count_call(node, build):
        calls[0] += 1
        return node

    isl.ast_build().set_at_each_domain(count_call).node_from(unrolled)
    assert calls[0] == 30
452
453
# Check basic AST generation from a schedule tree.
#
# Using the tree built by construct_schedule_tree:
# - generate an AST from it
# - check that set_at_each_domain returns a new build and leaves the build
#   it was called on without the callback
# - check that an exception raised by the at-each-domain callback is
#   propagated.  Installing another callback on a copy must not replace the
#   callback of the original build.
# - check unrolling (delegated to test_ast_build_unroll)
#
def test_ast_build():
    schedule = construct_schedule_tree()

    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    # The original build has no callback installed.
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    do_fail = True
    count_ast_fail = [0]

    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        # do_fail is read at call time, so switching it off below takes
        # effect for later AST generations.
        if do_fail:
            raise Exception("fail")
        return node

    build = isl.ast_build().set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        caught = True
    assert caught
    assert count_ast_fail[0] > 0

    # A copy with a different callback must not affect "build" itself.
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2

    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)
511
512
# Check basic AST expression generation from an affine expression.
#
# "n + 1" must become an addition operation with two arguments.
#
def test_ast_build_expr():
    pa = isl.pw_aff("[n] -> { [n + 1] }")
    expr = isl.ast_build.from_context(pa.domain()).expr_from(pa)

    assert type(expr) == isl.ast_expr_op_add
    assert expr.n_arg() == 2
522
523
524# Test the isl Python interface
525#
526# This includes:
527#  - Object construction
528#  - Different parameter types
529#  - Different return types
530#  - isl.id.user
531#  - Foreach functions
532#  - Foreach SCC function
533#  - Every functions
534#  - Spaces
535#  - Schedule trees
536#  - AST generation
537#  - AST expression generation
538#
539test_constructors()
540test_parameters()
541test_return()
542test_user()
543test_foreach()
544test_foreach_scc()
545test_every()
546test_space()
547test_schedule_tree()
548test_ast_build()
549test_ast_build_expr()
550