xref: /llvm-project/mlir/test/python/ir/affine_expr.py (revision fd45dcca26d6031fcbaa907d8d6c0d9755b36699)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
    """Execute test function ``f`` immediately and hand it back unchanged.

    Used as the ``@run`` decorator throughout this file, so each test runs at
    import time. A ``TEST: <name>`` banner is printed first; FileCheck label
    directives anchor on it. After the test body returns, a garbage
    collection is forced and the bindings' live ``Context`` count must be
    zero, so a test that leaks a context fails here.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Free unreachable objects (including any cycles that still hold a
    # Context) before counting live contexts.
    gc.collect()
    assert Context._get_live_count() == 0
    return f
13
14
15# CHECK-LABEL: TEST: testAffineExprCapsule
@run
def testAffineExprCapsule():
    """Round-trip an AffineExpr through its C-API capsule."""
    with Context() as context:
        original = AffineExpr.get_constant(42)

    # The capsule is extracted after leaving the `with` block.
    capsule = original._CAPIPtr
    # CHECK: capsule object
    # CHECK: mlir.ir.AffineExpr._CAPIPtr
    print(capsule)

    # Rebuilding from the capsule must give an equal expression that reports
    # the same owning context.
    rebuilt = AffineExpr._CAPICreate(capsule)
    assert original == rebuilt
    assert rebuilt.context == context
29
30
31# CHECK-LABEL: TEST: testAffineExprEq
@run
def testAffineExprEq():
    """Check AffineExpr equality, including comparison with foreign objects."""
    with Context():
        a1 = AffineExpr.get_constant(42)
        a2 = AffineExpr.get_constant(42)
        a3 = AffineExpr.get_constant(43)
        # CHECK: True
        print(a1 == a1)
        # CHECK: True
        print(a1 == a2)
        # CHECK: False
        print(a1 == a3)
        # `==` rather than `is` on purpose: this exercises the binding's
        # __eq__ with a None operand, which is expected to yield False
        # instead of raising. An identity test would be trivially False.
        # CHECK: False
        print(a1 == None)  # noqa: E711
        # CHECK: False
        print(a1 == "foo")
48
49
50# CHECK-LABEL: TEST: testAffineExprContext
@run
def testAffineExprContext():
    """Equal constants created in two different contexts compare unequal.

    ``@run`` already executes this test once; it is not invoked again
    explicitly, which would duplicate its output and its leak check.
    """
    with Context():
        a1 = AffineExpr.get_constant(42)
    with Context():
        a2 = AffineExpr.get_constant(42)

    # CHECK: False
    print(a1 == a2)
63
64
65# CHECK-LABEL: TEST: testAffineExprConstant
@run
def testAffineExprConstant():
    """Constants from the generic and the typed factory print and compare alike."""
    with Context():
        generic = AffineExpr.get_constant(42)
        # CHECK: 42
        print(generic.value)
        # CHECK: 42
        print(generic)

        typed = AffineConstantExpr.get(42)
        # CHECK: 42
        print(typed.value)
        # CHECK: 42
        print(typed)

        assert generic == typed
82
83
84# CHECK-LABEL: TEST: testAffineExprDim
@run
def testAffineExprDim():
    """Dimension expressions expose their position and print as ``d<pos>``."""
    with Context():
        dim_one = AffineExpr.get_dim(1)
        dim_one_typed = AffineDimExpr.get(1)
        dim_two = AffineDimExpr.get(2)

        # CHECK: 1
        print(dim_one.position)
        # CHECK: d1
        print(dim_one)

        # CHECK: 2
        print(dim_two.position)
        # CHECK: d2
        print(dim_two)

        # Generic and typed factories agree; different positions differ.
        assert dim_one == dim_one_typed
        assert dim_one != dim_two
104
105
106# CHECK-LABEL: TEST: testAffineExprSymbol
@run
def testAffineExprSymbol():
    """Symbol expressions expose their position and print as ``s<pos>``."""
    with Context():
        sym_one = AffineExpr.get_symbol(1)
        sym_one_typed = AffineSymbolExpr.get(1)
        sym_two = AffineSymbolExpr.get(2)

        # CHECK: 1
        print(sym_one.position)
        # CHECK: s1
        print(sym_one)

        # CHECK: 2
        print(sym_two.position)
        # CHECK: s2
        print(sym_two)

        # Generic and typed factories agree; different positions differ.
        assert sym_one == sym_one_typed
        assert sym_one != sym_two
126
127
128# CHECK-LABEL: TEST: testAffineAddExpr
@run
def testAffineAddExpr():
    """Build sums via get_add and the forward/reflected ``+`` operators."""
    with Context():
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        d12 = AffineExpr.get_add(d1, d2)
        # CHECK: d1 + d2
        print(d12)

        d12op = d1 + d2
        # CHECK: d1 + d2
        print(d12op)

        d1cst_op = d1 + 2
        # CHECK: d1 + 2
        print(d1cst_op)

        # The reflected `int + expr` form also prints the constant on the
        # right.
        d1cst_op2 = 2 + d1
        # CHECK: d1 + 2
        print(d1cst_op2)

        assert d12 == d12op
        # Compare the reflected form by value, not only by its printed text.
        assert d1cst_op == d1cst_op2
        assert d12.lhs == d1
        assert d12.rhs == d2
153
154
155# CHECK-LABEL: TEST: testAffineMulExpr
@run
def testAffineMulExpr():
    """Build ``d1 * 2`` via get_mul and every ``*`` operand combination."""
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        expr = AffineExpr.get_mul(d1, c2)
        # CHECK: d1 * 2
        print(expr)

        # CHECK: d1 * 2
        op = d1 * c2
        print(op)

        # CHECK: d1 * 2
        op_cst = d1 * 2
        print(op_cst)

        # The reflected `int * expr` form also prints the constant on the
        # right.
        # CHECK: d1 * 2
        op_cst2 = 2 * d1
        print(op_cst2)

        assert expr == op
        assert expr == op_cst
        # Compare the reflected form by value, not only by its printed text.
        assert expr == op_cst2
        assert expr.lhs == d1
        assert expr.rhs == c2
181
182
183# CHECK-LABEL: TEST: testAffineModExpr
@run
def testAffineModExpr():
    """``mod`` via get_mod, the ``%`` operators, and int/expr operand mixes."""
    with Context():
        dim = AffineDimExpr.get(1)
        two = AffineConstantExpr.get(2)
        built = AffineExpr.get_mod(dim, two)
        # CHECK: d1 mod 2
        print(built)

        # CHECK: d1 mod 2
        via_op = dim % two
        print(via_op)

        # CHECK: d1 mod 2
        via_int = dim % 2
        print(via_int)

        # CHECK: 2 mod d1
        print(2 % dim)

        assert built == via_op
        assert built == via_int
        assert built.lhs == dim
        assert built.rhs == two

        # get_mod accepts a plain int in place of a constant, on either side.
        const_lhs = AffineExpr.get_mod(two, dim)
        int_lhs = AffineExpr.get_mod(2, dim)
        int_rhs = AffineExpr.get_mod(dim, 2)

        # CHECK: 2 mod d1
        print(const_lhs)
        # CHECK: 2 mod d1
        print(int_lhs)
        # CHECK: d1 mod 2
        print(int_rhs)

        assert const_lhs == int_lhs
        assert int_rhs == built
222
223
224# CHECK-LABEL: TEST: testAffineFloorDivExpr
@run
def testAffineFloorDivExpr():
    """``floordiv`` via get_floor_div with expression and int operands."""
    with Context():
        dim = AffineDimExpr.get(1)
        two = AffineConstantExpr.get(2)
        quotient = AffineExpr.get_floor_div(dim, two)
        # CHECK: d1 floordiv 2
        print(quotient)

        assert quotient.lhs == dim
        assert quotient.rhs == two

        # A plain int is accepted in place of a constant, on either side.
        const_lhs = AffineExpr.get_floor_div(two, dim)
        int_lhs = AffineExpr.get_floor_div(2, dim)
        int_rhs = AffineExpr.get_floor_div(dim, 2)

        # CHECK: 2 floordiv d1
        print(const_lhs)
        # CHECK: 2 floordiv d1
        print(int_lhs)
        # CHECK: d1 floordiv 2
        print(int_rhs)

        assert const_lhs == int_lhs
        assert int_rhs == quotient
250
251
252# CHECK-LABEL: TEST: testAffineCeilDivExpr
@run
def testAffineCeilDivExpr():
    """``ceildiv`` via get_ceil_div with expression and int operands."""
    with Context():
        dim = AffineDimExpr.get(1)
        two = AffineConstantExpr.get(2)
        quotient = AffineExpr.get_ceil_div(dim, two)
        # CHECK: d1 ceildiv 2
        print(quotient)

        assert quotient.lhs == dim
        assert quotient.rhs == two

        # A plain int is accepted in place of a constant, on either side.
        const_lhs = AffineExpr.get_ceil_div(two, dim)
        int_lhs = AffineExpr.get_ceil_div(2, dim)
        int_rhs = AffineExpr.get_ceil_div(dim, 2)

        # CHECK: 2 ceildiv d1
        print(const_lhs)
        # CHECK: 2 ceildiv d1
        print(int_lhs)
        # CHECK: d1 ceildiv 2
        print(int_rhs)

        assert const_lhs == int_lhs
        assert int_rhs == quotient
278
279
280# CHECK-LABEL: TEST: testAffineExprSub
@run
def testAffineExprSub():
    """Subtraction prints as ``-`` and stores its rhs as ``rhs * -1``."""
    with Context():
        first = AffineDimExpr.get(1)
        second = AffineDimExpr.get(2)
        difference = first - second
        # CHECK: d1 - d2
        print(difference)

        assert difference.lhs == first
        # Downcast the rhs to a multiplication to inspect its two operands.
        scaled = AffineMulExpr(difference.rhs)
        # CHECK: d2
        print(scaled.lhs)
        # CHECK: -1
        print(scaled.rhs)

        # CHECK: d1 - 42
        print(first - 42)
        # CHECK: -d1 + 42
        print(42 - first)

        forty_two = AffineConstantExpr.get(42)
        assert first - 42 == first - forty_two
        assert 42 - first == forty_two - first
305
306
307# CHECK-LABEL: TEST: testClassHierarchy
@run
def testClassHierarchy():
    """Only the binary expression kinds are AffineBinaryExpr instances."""
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        # add, mul, mod, floordiv, ceildiv, built in that order.
        binary_exprs = [
            kind.get(d1, c2)
            for kind in (
                AffineAddExpr,
                AffineMulExpr,
                AffineModExpr,
                AffineFloorDivExpr,
                AffineCeilDivExpr,
            )
        ]

        # One line per expression: the two leaves first, then the five
        # binary expressions.
        # CHECK: False
        # CHECK: False
        # CHECK: True
        # CHECK: True
        # CHECK: True
        # CHECK: True
        # CHECK: True
        for candidate in [d1, c2] + binary_exprs:
            print(isinstance(candidate, AffineBinaryExpr))

        # Casting a non-binary expression to AffineBinaryExpr raises
        # ValueError; its message is printed once per leaf.
        # CHECK: Cannot cast affine expression to AffineBinaryExpr
        # CHECK: Cannot cast affine expression to AffineBinaryExpr
        for leaf in (d1, c2):
            try:
                AffineBinaryExpr(leaf)
            except ValueError as e:
                print(e)
345
346
347# CHECK-LABEL: TEST: testIsInstance
@run
def testIsInstance():
    """Each concrete class's static ``isinstance`` recognizes only its kind."""
    with Context():
        dim = AffineDimExpr.get(1)
        const = AffineConstantExpr.get(2)
        add = AffineAddExpr.get(dim, const)
        mul = AffineMulExpr.get(dim, const)

        # (class queried, expression tested), printed in this order.
        queries = [
            (AffineDimExpr, dim),
            (AffineConstantExpr, dim),
            (AffineConstantExpr, const),
            (AffineMulExpr, const),
            (AffineAddExpr, add),
            (AffineMulExpr, add),
            (AffineMulExpr, mul),
            (AffineAddExpr, mul),
        ]
        # CHECK: True
        # CHECK: False
        # CHECK: True
        # CHECK: False
        # CHECK: True
        # CHECK: False
        # CHECK: True
        # CHECK: False
        for kind, expr in queries:
            print(kind.isinstance(expr))
372
373
374# CHECK-LABEL: TEST: testCompose
@run
def testCompose():
    """Compose an affine expression with an affine map."""
    with Context():
        # d0 + d2.
        expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))

        # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
        result0 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
        result1 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
        result2 = AffineAddExpr.get(
            AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
            AffineDimExpr.get(2),
        )
        # 3 dims, 2 symbols. Named `affine_map` so the builtin `map` is not
        # shadowed.
        affine_map = AffineMap.get(3, 2, [result0, result1, result2])

        # d0 and d2 are replaced by map results 0 and 2. The expected output
        # shows that like terms are not folded together.
        # CHECK: d0 + s1 + d0 + d1 + d2
        print(expr.compose(affine_map))
392
393
394# CHECK-LABEL: TEST: testHash
@run
def testHash():
    """Equal affine expressions hash equally and work as dict keys."""
    with Context():
        d0 = AffineDimExpr.get(0)
        s1 = AffineSymbolExpr.get(1)
        assert hash(d0) == hash(AffineDimExpr.get(0))
        assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))

        dictionary = {}
        dictionary[d0] = 0
        dictionary[s1] = 1
        assert d0 in dictionary
        assert s1 in dictionary
        # Look up with freshly constructed, equal expressions, presumably
        # distinct wrapper objects. This exercises the __hash__/__eq__
        # lookup path rather than only the identity shortcut.
        assert dictionary[AffineDimExpr.get(0)] == 0
        assert dictionary[AffineSymbolExpr.get(1)] == 1
408