xref: /llvm-project/mlir/test/python/ir/integer_set.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
39f3f6d7bSStella Laurenzoimport gc
49f3f6d7bSStella Laurenzofrom mlir.ir import *
59f3f6d7bSStella Laurenzo
6*f9008e63STobias Hieta
def run(f):
    """Decorator that immediately executes the test function ``f``.

    Prints a ``TEST: <name>`` banner for FileCheck to anchor on, calls ``f``,
    forces a garbage collection, and asserts that no ``Context`` objects are
    still alive afterwards. Returns ``f`` so the decorated name stays bound
    to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect first so objects kept alive only by reference cycles are freed
    # before the live-context count is taken.
    gc.collect()
    live_count = Context._get_live_count()
    # Name the offending test and the leak size so a failure is actionable.
    assert live_count == 0, f"{f.__name__}: {live_count} Context(s) still alive"
    return f
139f3f6d7bSStella Laurenzo
149f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testIntegerSetCapsule
@run
def testIntegerSetCapsule():
    """Round-trip an IntegerSet through its C-API capsule and back."""
    with Context() as ctx:
        original = IntegerSet.get_empty(1, 1, ctx)
    ptr_capsule = original._CAPIPtr
    # CHECK: mlir.ir.IntegerSet._CAPIPtr
    print(ptr_capsule)
    recreated = IntegerSet._CAPICreate(ptr_capsule)
    # The recreated wrapper must compare equal and share the owning context.
    assert original == recreated
    assert recreated.context is ctx
269f3f6d7bSStella Laurenzo
279f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testIntegerSetGet
@run
def testIntegerSetGet():
    """Build IntegerSets directly and by replacement, then probe error paths."""
    with Context():
        dim0, dim1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        sym0 = AffineSymbolExpr.get(0)
        const42 = AffineConstantExpr.get(42)

        base = IntegerSet.get(2, 1, [dim0 - dim1, sym0 - const42], [True, False])
        empty = IntegerSet.get_empty(1, 1)
        replaced = base.get_replaced([dim0, AffineSymbolExpr.get(1)], [sym0], 1, 2)

        # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
        # CHECK: (d0)[s0] : (1 == 0)
        # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
        for integer_set in (base, empty, replaced):
            print(integer_set)

        # Each entry pairs the exception type the bindings must raise with the
        # offending call; the caught message is printed so it can be matched
        # in the same order as listed here.
        error_cases = [
            # CHECK: Expected non-empty list of constraints
            (ValueError, lambda: IntegerSet.get(2, 1, [], [])),
            # CHECK: Expected the number of constraints to match that of equality flags
            (ValueError, lambda: IntegerSet.get(2, 1, [dim0 - dim1], [True, False])),
            # CHECK: Invalid expression when attempting to create an IntegerSet
            (RuntimeError, lambda: IntegerSet.get(2, 1, [0], [True])),
            # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
            (RuntimeError, lambda: IntegerSet.get(2, 1, [None], [True])),
            # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
            (ValueError, lambda: base.get_replaced([dim0], [sym0], 1, 1)),
            # CHECK: Expected the number of symbol replacement expressions to match that of symbols
            (ValueError, lambda: base.get_replaced([dim0, dim1], [sym0, sym0], 1, 1)),
            # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
            (RuntimeError, lambda: base.get_replaced([dim0, 1], [sym0], 1, 1)),
            # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
            (RuntimeError, lambda: base.get_replaced([dim0, dim1], [None], 1, 1)),
        ]
        for expected_error, bad_call in error_cases:
            try:
                bad_call()
            except expected_error as err:
                print(err)
969f3f6d7bSStella Laurenzo
979f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testIntegerSetProperties
@run
def testIntegerSetProperties():
    """Query arity and constraint counts, then walk an IntegerSet's constraints."""
    with Context():
        dim0 = AffineDimExpr.get(0)
        dim1 = AffineDimExpr.get(1)
        sym0 = AffineSymbolExpr.get(0)
        const42 = AffineConstantExpr.get(42)

        integer_set = IntegerSet.get(
            2, 1, [dim0 - dim1, sym0 - const42, sym0 - dim0], [True, False, False]
        )

        # CHECK: 2
        # CHECK: 1
        # CHECK: 3
        # CHECK: 1
        # CHECK: 2
        scalar_properties = (
            "n_dims",
            "n_symbols",
            "n_inputs",
            "n_equalities",
            "n_inequalities",
        )
        for prop_name in scalar_properties:
            print(getattr(integer_set, prop_name))

        constraints = integer_set.constraints
        # CHECK: 3
        print(len(constraints))

        # CHECK-DAG: d0 - d1 == 0
        # CHECK-DAG: s0 - 42 >= 0
        # CHECK-DAG: -d0 + s0 >= 0
        for constraint in constraints:
            relation = "==" if constraint.is_eq else ">="
            print(constraint.expr, relation, "0")
1289f3f6d7bSStella Laurenzo
129fc7594ccSAlex Zinenko
# TODO-LABEL: TEST: testHash
@run
def testHash():
    """Equal IntegerSets must hash equally and be interchangeable as dict keys."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        # Named `int_set` rather than `set` to avoid shadowing the builtin.
        int_set = IntegerSet.get(2, 0, [d0 + d1], [True])
        # A second, separately constructed but structurally identical set;
        # presumably a distinct Python wrapper object (not identical to
        # `int_set`), which is what makes the dict probe below meaningful.
        same_set = IntegerSet.get(2, 0, [d0 + d1], [True])

        assert int_set == same_set
        assert hash(int_set) == hash(same_set)

        dictionary = {int_set: 42}
        assert int_set in dictionary
        # Probe with the *other* object: CPython dict lookup short-circuits on
        # identity, so probing with the original key alone would pass even if
        # __hash__ and __eq__ were inconsistent.
        assert same_set in dictionary
        assert dictionary[same_set] == 42
142fc7594ccSAlex Zinenko        assert set in dictionary
143