# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
    """Run test ``f`` immediately under a printed label, then check for leaks.

    After the test returns, a garbage collection is forced and the run fails
    if any ``Context`` is still alive. Returns ``f`` so this can be used as a
    decorator.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


def _expect_error(exc_type, fn, *args):
    """Call ``fn(*args)`` and print the ``exc_type`` exception it raises.

    The printed message is what the FileCheck patterns match. If ``fn``
    returns normally, fail right away instead of silently printing nothing,
    so a missing diagnostic is reported at its source rather than only as an
    unmatched pattern.

    Raises:
        AssertionError: if ``fn(*args)`` does not raise ``exc_type``.
    """
    try:
        fn(*args)
    except exc_type as e:
        print(e)
    else:
        raise AssertionError(
            f"expected {exc_type.__name__} from {fn!r} with arguments {args!r}"
        )


# CHECK-LABEL: TEST: testAffineMapCapsule
@run
def testAffineMapCapsule():
    """Round-trip an AffineMap through its C-API capsule."""
    with Context() as ctx:
        am1 = AffineMap.get_empty(ctx)
        # CHECK: mlir.ir.AffineMap._CAPIPtr
        affine_map_capsule = am1._CAPIPtr
        print(affine_map_capsule)
        am2 = AffineMap._CAPICreate(affine_map_capsule)
        assert am2 == am1
        # The recreated map must be attached to the very same Context object.
        assert am2.context is ctx


# CHECK-LABEL: TEST: testAffineMapGet
@run
def testAffineMapGet():
    """Build maps with each factory method and exercise the error paths."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)

        # CHECK: (d0, d1)[s0, s1, s2] -> ()
        map0 = AffineMap.get(2, 3, [])
        print(map0)

        # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
        map1 = AffineMap.get(2, 3, [d1, c2])
        print(map1)

        # CHECK: () -> (2)
        map2 = AffineMap.get(0, 0, [c2])
        print(map2)

        # CHECK: (d0, d1) -> (d0, d1)
        map3 = AffineMap.get(2, 0, [d0, d1])
        print(map3)

        # CHECK: (d0, d1) -> (d1)
        map4 = AffineMap.get(2, 0, [d1])
        print(map4)

        # CHECK: (d0, d1, d2) -> (d2, d0, d1)
        map5 = AffineMap.get_permutation([2, 0, 1])
        print(map5)

        # The convenience constructors must agree with the generic `get`.
        assert map1 == AffineMap.get(2, 3, [d1, c2])
        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
        assert map2 == AffineMap.get_constant(2)
        assert map3 == AffineMap.get_identity(2)
        assert map4 == AffineMap.get_minor_identity(2, 1)

        # Invalid inputs: every call below must raise; its message is printed.
        # CHECK: Invalid expression when attempting to create an AffineMap
        _expect_error(RuntimeError, AffineMap.get, 1, 1, [1])

        # CHECK: Invalid expression (None?) when attempting to create an AffineMap
        _expect_error(RuntimeError, AffineMap.get, 1, 1, [None])

        # CHECK: Invalid permutation when attempting to create an AffineMap
        _expect_error(RuntimeError, AffineMap.get_permutation, [1, 0, 1])

        # CHECK: result position out of bounds
        _expect_error(ValueError, map3.get_submap, [42])

        # CHECK: number of results out of bounds
        _expect_error(ValueError, map3.get_minor_submap, 42)

        # CHECK: number of results out of bounds
        _expect_error(ValueError, map3.get_major_submap, 42)


# CHECK-LABEL: TEST: testAffineMapDerive
@run
def testAffineMapDerive():
    """Derive sub-maps (selected, leading, trailing results) from an identity."""
    with Context():
        map5 = AffineMap.get_identity(5)

        # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
        map123 = map5.get_submap([1, 2, 3])
        print(map123)

        # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
        map01 = map5.get_major_submap(2)
        print(map01)

        # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
        map34 = map5.get_minor_submap(2)
        print(map34)


# CHECK-LABEL: TEST: testAffineMapProperties
@run
def testAffineMapProperties():
    """Check the permutation predicates on full, partial and symbolic maps."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        map1 = AffineMap.get(3, 0, [d2, d0])
        map2 = AffineMap.get(3, 0, [d2, d0, d1])
        map3 = AffineMap.get(3, 1, [d2, d0, d1])
        # CHECK: False
        print(map1.is_permutation)
        # CHECK: True
        print(map1.is_projected_permutation)
        # CHECK: True
        print(map2.is_permutation)
        # CHECK: True
        print(map2.is_projected_permutation)
        # CHECK: False
        print(map3.is_permutation)
        # CHECK: False
        print(map3.is_projected_permutation)


# CHECK-LABEL: TEST: testAffineMapExprs
@run
def testAffineMapExprs():
    """Inspect dimension/symbol counts and iterate over a map's results."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        map3 = AffineMap.get(3, 1, [d2, d0, d1])

        # CHECK: 3
        print(map3.n_dims)
        # CHECK: 4
        print(map3.n_inputs)
        # CHECK: 1
        print(map3.n_symbols)
        assert map3.n_inputs == map3.n_dims + map3.n_symbols

        # CHECK: 3
        print(len(map3.results))
        for expr in map3.results:
            # CHECK: d2
            # CHECK: d0
            # CHECK: d1
            print(expr)
        # Negative-step slice: walks the results back to front.
        for expr in map3.results[-1:-4:-1]:
            # CHECK: d1
            # CHECK: d0
            # CHECK: d2
            print(expr)
        assert list(map3.results) == [d2, d0, d1]


# CHECK-LABEL: TEST: testCompressUnusedSymbols
@run
def testCompressUnusedSymbols():
    """Drop symbols that no map in the list uses and renumber the rest."""
    with Context() as ctx:
        d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2))
        s0, s1, s2 = (
            AffineSymbolExpr.get(0),
            AffineSymbolExpr.get(1),
            AffineSymbolExpr.get(2),
        )
        # Only s2 is referenced (in the second map), so compression should
        # leave a single symbol, renamed to s0.
        maps = [
            AffineMap.get(3, 3, [d2, d0, d1]),
            AffineMap.get(3, 3, [d2, d0 + s2, d1]),
            AffineMap.get(3, 3, [d1, d2, d0]),
        ]

        compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)

        # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
        print(maps)

        # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
        print(compressed_maps)


# CHECK-LABEL: TEST: testReplace
@run
def testReplace():
    """Substitute a constant for each symbol in turn."""
    with Context():
        d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2))
        s0, s1, s2 = (
            AffineSymbolExpr.get(0),
            AffineSymbolExpr.get(1),
            AffineSymbolExpr.get(2),
        )
        map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])

        # The last two arguments set the dim/symbol counts of the result map;
        # replace2 also shrinks the symbol list to two after removing s2.
        replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
        replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
        replace2 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
        print(replace0)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
        print(replace1)

        # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
        print(replace2)


# CHECK-LABEL: TEST: testHash
@run
def testHash():
    """Maps must be usable as dict keys: equal maps hash equal, distinct ones stay distinct."""
    with Context():
        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        m1 = AffineMap.get(2, 0, [d0, d1])
        m2 = AffineMap.get(2, 0, [d1, d0])
        assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))

        dictionary = {m1: 1, m2: 2}
        # m1 and m2 differ, so both entries must survive as separate keys.
        assert len(dictionary) == 2
        assert m1 in dictionary
        assert m2 in dictionary
        assert dictionary[m2] == 2
        # An equal map built independently must find m1's entry.
        assert dictionary[AffineMap.get(2, 0, [d0, d1])] == 1