xref: /llvm-project/mlir/test/python/ir/affine_map.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
    """Execute test function ``f`` at definition time and check for leaks.

    Prints a ``TEST: <name>`` header so the output can be split per test,
    calls ``f``, forces a garbage-collection pass, and asserts that no MLIR
    ``Context`` is still alive afterwards. ``f`` is returned unchanged, which
    lets ``run`` be used as a decorator.
    """
    header = "\nTEST: " + f.__name__
    print(header)
    f()
    # Collect first so that contexts reachable only from unreachable reference
    # cycles are released before the live ones are counted.
    gc.collect()
    assert Context._get_live_count() == 0
    return f
13
14
# CHECK-LABEL: TEST: testAffineMapCapsule
@run
def testAffineMapCapsule():
    """Round-trip an AffineMap through its ``_CAPIPtr`` capsule."""
    with Context() as ctx:
        original = AffineMap.get_empty(ctx)
    capsule = original._CAPIPtr
    # CHECK: mlir.ir.AffineMap._CAPIPtr
    print(capsule)
    restored = AffineMap._CAPICreate(capsule)
    # The recreated map must compare equal to the original and report the
    # very Context object it was created in.
    assert restored == original
    assert restored.context is ctx
26
27
# CHECK-LABEL: TEST: testAffineMapGet
@run
def testAffineMapGet():
    """Build AffineMaps through each factory and print errors for bad input."""

    def print_failure(expected_error, factory, *args):
        # Call ``factory(*args)``; if it raises ``expected_error``, print the
        # message. Any other exception type propagates to the caller, and a
        # call that succeeds prints nothing.
        try:
            factory(*args)
        except expected_error as error:
            print(error)

    with Context():
        dim0, dim1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        two = AffineConstantExpr.get(2)

        # CHECK: (d0, d1)[s0, s1, s2] -> ()
        no_results = AffineMap.get(2, 3, [])
        print(no_results)

        # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
        with_symbols = AffineMap.get(2, 3, [dim1, two])
        print(with_symbols)

        # CHECK: () -> (2)
        constant = AffineMap.get(0, 0, [two])
        print(constant)

        # CHECK: (d0, d1) -> (d0, d1)
        identity = AffineMap.get(2, 0, [dim0, dim1])
        print(identity)

        # CHECK: (d0, d1) -> (d1)
        minor_identity = AffineMap.get(2, 0, [dim1])
        print(minor_identity)

        # CHECK: (d0, d1, d2) -> (d2, d0, d1)
        permutation = AffineMap.get_permutation([2, 0, 1])
        print(permutation)

        # Rebuilt maps compare equal, and each convenience factory agrees with
        # the equivalent generic get() call.
        assert with_symbols == AffineMap.get(2, 3, [dim1, two])
        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
        assert constant == AffineMap.get_constant(2)
        assert identity == AffineMap.get_identity(2)
        assert minor_identity == AffineMap.get_minor_identity(2, 1)

        # CHECK: Invalid expression when attempting to create an AffineMap
        print_failure(RuntimeError, AffineMap.get, 1, 1, [1])

        # CHECK: Invalid expression (None?) when attempting to create an AffineMap
        print_failure(RuntimeError, AffineMap.get, 1, 1, [None])

        # CHECK: Invalid permutation when attempting to create an AffineMap
        print_failure(RuntimeError, AffineMap.get_permutation, [1, 0, 1])

        # CHECK: result position out of bounds
        print_failure(ValueError, identity.get_submap, [42])

        # CHECK: number of results out of bounds
        print_failure(ValueError, identity.get_minor_submap, 42)

        # CHECK: number of results out of bounds
        print_failure(ValueError, identity.get_major_submap, 42)
101
102
# CHECK-LABEL: TEST: testAffineMapDerive
@run
def testAffineMapDerive():
    """Derive submaps from a five-dimensional identity map."""
    with Context():
        identity5 = AffineMap.get_identity(5)

        # Explicitly chosen result positions.
        # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
        print(identity5.get_submap([1, 2, 3]))

        # Leading results.
        # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
        print(identity5.get_major_submap(2))

        # Trailing results.
        # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
        print(identity5.get_minor_submap(2))
120
121
# CHECK-LABEL: TEST: testAffineMapProperties
@run
def testAffineMapProperties():
    """Print the (projected-)permutation predicates for three maps."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(pos) for pos in range(3)]
        projection = AffineMap.get(3, 0, [d2, d0])
        permutation = AffineMap.get(3, 0, [d2, d0, d1])
        with_symbol = AffineMap.get(3, 1, [d2, d0, d1])

        # Each map prints is_permutation, then is_projected_permutation.
        # [d2, d0] omits d1: only a projected permutation.
        # CHECK: False
        # CHECK: True
        # [d2, d0, d1] over three dims: a full permutation.
        # CHECK: True
        # CHECK: True
        # Same results but with one symbol: neither predicate holds.
        # CHECK: False
        # CHECK: False
        for candidate in (projection, permutation, with_symbol):
            print(candidate.is_permutation)
            print(candidate.is_projected_permutation)
144
145
# CHECK-LABEL: TEST: testAffineMapExprs
@run
def testAffineMapExprs():
    """Query input counts and iterate/slice the result expressions of a map."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(pos) for pos in range(3)]
        amap = AffineMap.get(3, 1, [d2, d0, d1])

        # CHECK: 3
        # CHECK: 4
        # CHECK: 1
        for count in (amap.n_dims, amap.n_inputs, amap.n_symbols):
            print(count)
        # The input count is the dimension count plus the symbol count.
        assert amap.n_inputs == amap.n_dims + amap.n_symbols

        # CHECK: 3
        print(len(amap.results))

        # Forward iteration yields the results in declaration order.
        # CHECK: d2
        # CHECK: d0
        # CHECK: d1
        for result in amap.results:
            print(result)

        # A negative-step slice walks them back to front.
        # CHECK: d1
        # CHECK: d0
        # CHECK: d2
        for result in amap.results[-1:-4:-1]:
            print(result)

        assert list(amap.results) == [d2, d0, d1]
176
177
# CHECK-LABEL: TEST: testCompressUnusedSymbols
@run
def testCompressUnusedSymbols():
    """Compress the symbol lists of a group of maps down to the used ones."""
    with Context() as ctx:
        d0, d1, d2 = [AffineDimExpr.get(pos) for pos in range(3)]
        s0, s1, s2 = [AffineSymbolExpr.get(pos) for pos in range(3)]
        result_lists = ([d2, d0, d1], [d2, d0 + s2, d1], [d1, d2, d0])
        maps = [AffineMap.get(3, 3, results) for results in result_lists]

        compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)

        # The input maps are printed unchanged after compression.
        #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
        print(maps)

        # Only s2 is referenced (by the second map), so every compressed map
        # keeps a single symbol and s2 becomes s0.
        #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
        print(compressed_maps)
205
206
# CHECK-LABEL: TEST: testReplace
@run
def testReplace():
    """Substitute a constant for one symbol at a time via AffineMap.replace."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(pos) for pos in range(3)]
        s0, s1, s2 = [AffineSymbolExpr.get(pos) for pos in range(3)]
        base = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
        forty_two = AffineConstantExpr.get(42)

        replaced_s0 = base.replace(s0, forty_two, 3, 3)
        replaced_s1 = base.replace(s1, forty_two, 3, 3)
        # The trailing arguments size the resulting map; lowering the symbol
        # count to 2 drops s2, which no longer appears after the replacement.
        replaced_s2 = base.replace(s2, forty_two, 3, 2)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
        print(replaced_s0)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
        print(replaced_s1)

        # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
        print(replaced_s2)
231
232
# CHECK-LABEL: TEST: testHash
@run
def testHash():
    """Check that AffineMap is hashable and behaves correctly as a dict key."""
    with Context():
        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        m1 = AffineMap.get(2, 0, [d0, d1])
        m2 = AffineMap.get(2, 0, [d1, d0])
        # Equal maps built independently must hash identically, while the
        # two differently ordered maps must not compare equal.
        assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))
        assert m1 != m2

        dictionary = {}
        dictionary[m1] = 1
        dictionary[m2] = 2
        # Membership alone would not catch a hash/eq bug that merged distinct
        # keys, so also check the entry count, each stored value, and lookup
        # through an independently rebuilt equal key.
        assert m1 in dictionary
        assert len(dictionary) == 2
        assert dictionary[m1] == 1
        assert dictionary[m2] == 2
        assert dictionary[AffineMap.get(2, 0, [d0, d1])] == 1
246