# Copyright 2016-2017 Tobias Grosser
#
# Use of this software is governed by the MIT license
#
# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich

import sys
import isl


# Test that isl objects can be constructed.
#
# This tests:
#  - construction from a string
#  - construction from an integer
#  - static constructor without a parameter
#  - conversion construction
#  - construction of empty union set
#
# The tests to construct from integers and strings cover functionality that
# is also tested in the parameter type tests, but here the presence of
# multiple overloaded constructors and overload resolution is tested.
#
def test_constructors():
    # Three different ways of obtaining the value zero.
    zero_from_str = isl.val("0")
    assert zero_from_str.is_zero()

    zero_from_int = isl.val(0)
    assert zero_from_int.is_zero()

    zero_from_static = isl.val.zero()
    assert zero_from_static.is_zero()

    # Conversion constructor: build a set from a basic_set.
    expected_set = isl.set("{ [1] }")
    converted = isl.set(isl.basic_set("{ [1] }"))
    assert converted.is_equal(expected_set)

    # Taking the union with the empty union set must not change anything.
    nonempty = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert nonempty.is_equal(nonempty.union(nothing))


# Test integer function parameters for a particular integer value.
#
# The value is passed once as a Python integer and once as its decimal
# string representation; both constructions must yield equal isl.val objects.
#
def test_int(i):
    from_int = isl.val(i)
    from_str = isl.val(str(i))
    assert from_int.eq(from_str)


# Test integer function parameters.
#
# Verify that extreme values (the bounds of a Py_ssize_t, as given by
# sys.maxsize) and zero work.
#
def test_parameters_int():
    for value in (sys.maxsize, -sys.maxsize - 1, 0):
        test_int(value)


# Test isl objects parameters.
#
# Verify that isl objects can be passed as lvalue and rvalue parameters.
# Also verify that isl object parameters are automatically type converted if
# there is an inheritance relation. Finally, test function calls without
# any additional parameters, apart from the isl object on which
# the method is called.
#
def test_parameters_obj():
    set_0 = isl.set("{ [0] }")
    set_1 = isl.set("{ [1] }")
    set_2 = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # The intermediate result is bound to a name before being passed on.
    partial = set_0.union(set_1)
    via_named = partial.union(set_2)
    assert via_named.is_equal(expected)

    # The intermediate result is passed on directly as a temporary.
    via_temporary = set_0.union(set_1).union(set_2)
    assert via_temporary.is_equal(expected)

    # A basic_set argument is accepted where a set method is called.
    basic_0 = isl.basic_set("{ [0] }")
    assert set_0.is_equal(basic_0)

    # A method that takes no argument besides the object itself.
    two = isl.val(2)
    half = isl.val("1/2")
    inverse = two.inv()
    assert inverse.eq(half)


# Test different kinds of parameters to be passed to functions.
#
# This includes integer and isl object parameters.
#
def test_parameters():
    for run_check in (test_parameters_int, test_parameters_obj):
        run_check()


# Test that isl objects are returned correctly.
#
# This only tests that after combining two objects, the result is successfully
# returned.
#
def test_return_obj():
    three = isl.val("3")
    total = isl.val("1").add(isl.val("2"))
    assert total.eq(three)


# Test that integer values are returned correctly.
#
# sgn() must report a positive, a negative and a zero sign respectively.
#
def test_return_int():
    assert isl.val("1").sgn() > 0
    assert isl.val("-1").sgn() < 0
    assert isl.val("0").sgn() == 0


# Test that isl_bool values are returned correctly.
#
# In particular, check the conversion to bool in case of true and false.
#
def test_return_bool():
    empty_is_empty = isl.set("{ : false }").is_empty()
    universe_is_empty = isl.set("{ : }").is_empty()

    assert empty_is_empty
    assert not universe_is_empty


# Test that strings are returned correctly.
# Do so by calling overloaded isl.ast_build.expr_from methods.
#
def test_return_string():
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Named so as not to shadow the builtin "set".
    nonneg = isl.set("[n] -> { : n >= 0 }")

    # expr_from applied to a pw_aff.
    expr = build.expr_from(pw_aff)
    expected_string = "n"
    assert expected_string == expr.to_C_str()

    # expr_from applied to a set.
    expr = build.expr_from(nonneg)
    expected_string = "n >= 0"
    assert expected_string == expr.to_C_str()


# Test that return values are handled correctly.
#
# Test that isl objects, integers, boolean values, and strings are
# returned correctly.
#
def test_return():
    test_return_obj()
    test_return_int()
    test_return_bool()
    test_return_string()


# A class that is used to test isl.id.user.
#
class S:
    def __init__(self):
        self.value = 42


# Test isl.id.user.
#
# In particular, check that the object attached to an identifier
# can be retrieved again, and that an identifier created without
# a user object reports None.
#
def test_user():
    # Local names avoid shadowing the builtin "id".
    id_int = isl.id("test", 5)
    id_plain = isl.id("test2")
    id_obj = isl.id("S", S())
    assert id_int.user() == 5, f"unexpected user object {id_int.user()}"
    assert id_plain.user() is None, f"unexpected user object {id_plain.user()}"
    s = id_obj.user()
    assert isinstance(s, S), f"unexpected user object {s}"
    assert s.value == 42, f"unexpected user object {s}"


# Test that foreach functions are modeled correctly.
#
# Verify that closures are correctly called as callback of a 'foreach'
# function and that variables captured by the closure work correctly. Also
# check that the foreach function handles exceptions raised in
# the closure and that it propagates the exception.
#
def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # Captured by the closure below; named so as not to shadow "list".
    basic_sets = []

    def add(bs):
        basic_sets.append(bs)

    s.foreach_basic_set(add)

    # Each of the three basic sets is part of s and they are pairwise distinct.
    assert len(basic_sets) == 3
    assert basic_sets[0].is_subset(s)
    assert basic_sets[1].is_subset(s)
    assert basic_sets[2].is_subset(s)
    assert not basic_sets[0].is_equal(basic_sets[1])
    assert not basic_sets[0].is_equal(basic_sets[2])
    assert not basic_sets[1].is_equal(basic_sets[2])

    def fail(bs):
        raise Exception("fail")

    # The exception raised by the callback must reach the caller.
    # Catch Exception rather than everything, so that KeyboardInterrupt
    # and SystemExit are not swallowed.
    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        caught = True
    assert caught


# Test the functionality of "foreach_scc" functions.
#
# In particular, test it on a list of elements that can be completely sorted
# but where two of the elements ("a" and "b") are incomparable.
#
def test_foreach_scc():
    ids = isl.id_list(3)
    # One-element list so that the nested callback can replace the entry.
    sorted_ids = [isl.id_list(3)]
    data = {
        "a": isl.map("{ [0] -> [1] }"),
        "b": isl.map("{ [1] -> [0] }"),
        "c": isl.map("{ [i = 0:1] -> [i] }"),
    }
    for name in data:
        ids = ids.add(name)
    identity = data["a"].space().domain().identity_multi_pw_aff_on_domain()

    # Ordering callback passed to foreach_scc (presumably "a must be placed
    # after b"): compose the map of b with that of a and check whether some
    # element is mapped to a lexicographically greater-or-equal value.
    def follows(a, b):
        composed = data[b.name()].apply_domain(data[a.name()])
        return not composed.lex_ge_at(identity).is_empty()

    # Every strongly connected component is expected to be a singleton;
    # append it to the result.
    def add_single(scc):
        assert scc.size() == 1
        sorted_ids[0] = sorted_ids[0].concat(scc)

    ids.foreach_scc(follows, add_single)
    assert sorted_ids[0].size() == 3
    assert sorted_ids[0].at(0).name() == "b"
    assert sorted_ids[0].at(1).name() == "c"
    assert sorted_ids[0].at(2).name() == "a"


# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
# test that exceptions are properly propagated.
#
def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()

    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()

    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(not_in_A)

    def fail(s):
        raise Exception("fail")

    # The method must be spelled correctly. A misspelled name raises
    # AttributeError, which this handler also catches, so the check would
    # pass without exercising exception propagation through every_set at all.
    caught = False
    try:
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught


# Check basic construction of spaces.
#
def test_space():
    unit = isl.space.unit()
    set_space = unit.add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    # Named so as not to shadow the builtins "set" and "map".
    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)
    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))


# Construct a simple schedule tree with an outer sequence node and
# a single-dimensional band node in each branch, with one of them
# marked coincident.
#
def construct_schedule_tree():
    A = isl.union_set("{ A[i] : 0 <= i < 10 }")
    B = isl.union_set("{ B[i] : 0 <= i < 20 }")

    node = isl.schedule_node.from_domain(A.union(B))
    node = node.child(0)

    filters = isl.union_set_list(A).add(B)
    node = node.insert_sequence(filters)

    # First branch: band over A, with member 0 marked coincident.
    f_A = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    node = node.child(0)
    node = node.child(0)
    node = node.insert_partial_schedule(f_A)
    node = node.member_set_coincident(0, True)
    node = node.ancestor(2)

    # Second branch: band over B.
    f_B = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    node = node.child(1)
    node = node.child(0)
    node = node.insert_partial_schedule(f_B)
    node = node.ancestor(2)

    return node.schedule()


# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
# - check that the root node is a domain node
# - test map_descendant_bottom_up
# - test foreach_descendant_top_down
# - test every_descendant
# - check that the union of all filters equals the domain
#
def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    # Exact type check: the root must be a domain node.
    assert type(root) is isl.schedule_node_domain

    count = [0]

    def count_and_keep(node):
        count[0] += 1
        return node

    root = root.map_descendant_bottom_up(count_and_keep)
    assert count[0] == 8

    def fail_map(node):
        raise Exception("fail")

    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        caught = True
    assert caught

    count = [0]

    # Returning True continues the traversal into the children.
    def count_and_continue(node):
        count[0] += 1
        return True

    root.foreach_descendant_top_down(count_and_continue)
    assert count[0] == 8

    count = [0]

    # Returning False stops the descent, so only the root is visited.
    def count_and_stop(node):
        count[0] += 1
        return False

    root.foreach_descendant_top_down(count_and_stop)
    assert count[0] == 1

    def is_not_domain(node):
        return type(node) is not isl.schedule_node_domain

    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    def fail(node):
        raise Exception("fail")

    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught

    domain = root.domain()
    filters = [isl.union_set("{}")]

    def collect_filters(node):
        if type(node) is isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True

    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])


# Test marking band members for unrolling.
# "schedule" is the schedule created by construct_schedule_tree.
# It schedules two statements, with 10 and 20 instances, respectively.
# Unrolling all band members therefore results in 30 at-domain calls
# by the AST generator.
#
def test_ast_build_unroll(schedule):
    root = schedule.root()

    def mark_unroll(node):
        if type(node) is isl.schedule_node_band:
            node = node.member_set_ast_loop_unroll(0)
        return node

    root = root.map_descendant_bottom_up(mark_unroll)
    schedule = root.schedule()

    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(inc_count_ast)
    # Only the number of callback invocations is checked; the AST is unused.
    build.node_from(schedule)
    assert count_ast[0] == 30


# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
# - generate an AST from the schedule tree
# - test at_each_domain
# - test unrolling
#
def test_ast_build():
    schedule = construct_schedule_tree()

    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # set_at_each_domain returns a new build; the test checks that the
    # original build does not invoke the callback.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # Read by the closure below at call time, so rebinding it later
    # changes the callback's behavior.
    do_fail = True
    count_ast_fail = [0]

    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        caught = True
    assert caught
    assert count_ast_fail[0] > 0

    # Replacing the callback on a derived build works after a failure.
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2

    # Once the callback stops raising, the original build works again.
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)


# Test basic AST expression generation from an affine expression.
#
def test_ast_build_expr():
    pa = isl.pw_aff("[n] -> { [n + 1] }")
    build = isl.ast_build.from_context(pa.domain())

    op = build.expr_from(pa)
    assert type(op) is isl.ast_expr_op_add
    assert op.n_arg() == 2


# Test the isl Python interface
#
# This includes:
#  - Object construction
#  - Different parameter types
#  - Different return types
#  - isl.id.user
#  - Foreach functions
#  - Foreach SCC function
#  - Every functions
#  - Spaces
#  - Schedule trees
#  - AST generation
#  - AST expression generation
#
test_constructors()
test_parameters()
test_return()
test_user()
test_foreach()
test_foreach_scc()
test_every()
test_space()
test_schedule_tree()
test_ast_build()
test_ast_build_expr()