# xref: /llvm-project/mlir/test/python/ir/integer_set.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
# RUN: %PYTHON %s | FileCheck %s

3import gc
4from mlir.ir import *
5
6
def run(f):
    """Execute a test function at definition time and check for leaked contexts.

    Prints a ``TEST: <name>`` header (the text the ``CHECK-LABEL`` directives
    in this file anchor on), calls *f*, forces a garbage-collection pass, and
    asserts that no ``Context`` is still alive afterwards.

    Returns *f* itself, so this can be used as a bare decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect garbage first so objects the test dropped (which may still hold
    # a Context) are freed before the live count is inspected.
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testIntegerSetCapsule
@run
def testIntegerSetCapsule():
    """Round-trip an IntegerSet through its C-API capsule.

    The set is created inside a context, exported via ``_CAPIPtr`` and
    re-imported with ``_CAPICreate``. The re-imported set must compare equal
    to the original and report the very same Context object.
    """
    with Context() as ctx:
        original = IntegerSet.get_empty(1, 1, ctx)
    capsule = original._CAPIPtr
    # CHECK: mlir.ir.IntegerSet._CAPIPtr
    print(capsule)
    restored = IntegerSet._CAPICreate(capsule)
    assert original == restored
    assert restored.context is ctx


# CHECK-LABEL: TEST: testIntegerSetGet
@run
def testIntegerSetGet():
    """Exercise IntegerSet construction, replacement and argument validation.

    Builds sets with ``IntegerSet.get``, ``get_empty`` and ``get_replaced``
    and prints them. It then feeds malformed arguments to those factories and
    prints each resulting error message. Each message is matched by the CHECK
    line placed beside its case in ``failing_calls``.
    """
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        c42 = AffineConstantExpr.get(42)

        # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
        print(set0)

        # CHECK: (d0)[s0] : (1 == 0)
        print(IntegerSet.get_empty(1, 1))

        # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
        s1 = AffineSymbolExpr.get(1)
        print(set0.get_replaced([d0, s1], [s0], 1, 2))

        # Each entry pairs the exception type expected from the binding with a
        # zero-argument callable that triggers it. Order matters: the CHECK
        # lines below must appear in the same order as the printed messages.
        failing_calls = [
            # CHECK: Expected non-empty list of constraints
            (ValueError, lambda: IntegerSet.get(2, 1, [], [])),
            # CHECK: Expected the number of constraints to match that of equality flags
            (ValueError, lambda: IntegerSet.get(2, 1, [d0 - d1], [True, False])),
            # CHECK: Invalid expression when attempting to create an IntegerSet
            (RuntimeError, lambda: IntegerSet.get(2, 1, [0], [True])),
            # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
            (RuntimeError, lambda: IntegerSet.get(2, 1, [None], [True])),
            # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
            (ValueError, lambda: set0.get_replaced([d0], [s0], 1, 1)),
            # CHECK: Expected the number of symbol replacement expressions to match that of symbols
            (ValueError, lambda: set0.get_replaced([d0, d1], [s0, s0], 1, 1)),
            # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
            (RuntimeError, lambda: set0.get_replaced([d0, 1], [s0], 1, 1)),
            # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
            (RuntimeError, lambda: set0.get_replaced([d0, d1], [None], 1, 1)),
        ]
        for expected_error, trigger in failing_calls:
            try:
                trigger()
            except expected_error as err:
                print(err)


# CHECK-LABEL: TEST: testIntegerSetProperties
@run
def testIntegerSetProperties():
    """Check the counting properties and constraint iteration of an IntegerSet.

    The set is built with 2 dims, 1 symbol, one equality and two
    inequalities. Each count is printed and matched by the CHECK line above
    it. Constraints are then printed as ``<expr> == 0`` or ``<expr> >= 0``.
    CHECK-DAG matches those lines in any order.
    """
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        c42 = AffineConstantExpr.get(42)

        exprs = [d0 - d1, s0 - c42, s0 - d0]
        eq_flags = [True, False, False]
        iset = IntegerSet.get(2, 1, exprs, eq_flags)
        # CHECK: 2
        print(iset.n_dims)
        # CHECK: 1
        print(iset.n_symbols)
        # CHECK: 3
        print(iset.n_inputs)
        # CHECK: 1
        print(iset.n_equalities)
        # CHECK: 2
        print(iset.n_inequalities)

        # CHECK: 3
        print(len(iset.constraints))

        # CHECK-DAG: d0 - d1 == 0
        # CHECK-DAG: s0 - 42 >= 0
        # CHECK-DAG: -d0 + s0 >= 0
        for constraint in iset.constraints:
            text = str(constraint.expr)
            relation = " == 0" if constraint.is_eq else " >= 0"
            print(text + relation)


# CHECK-LABEL: TEST: testHash
@run
def testHash():
    """Check that IntegerSet hashing is consistent with equality.

    Two separately constructed but identical sets must compare equal and hash
    alike. A set must also work as a dict key, including lookup through the
    separately constructed equal set.
    """
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        # Named ``int_set`` rather than ``set`` to avoid shadowing the builtin.
        int_set = IntegerSet.get(2, 0, [d0 + d1], [True])
        same_set = IntegerSet.get(2, 0, [d0 + d1], [True])

        assert int_set == same_set
        assert hash(int_set) == hash(same_set)

        dictionary = {int_set: 42}
        assert int_set in dictionary
        # The hash/eq contract only matters if an equal key finds the entry.
        assert dictionary[same_set] == 42
143