# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import func
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine
import mlir.extras.types as T


def constructAndPrintInModule(f):
    """Run ``f`` once inside a fresh module, then print the module.

    Used as a decorator, so each test body executes at definition time. It
    runs under a new ``Context`` and an unknown ``Location``, with the module
    body as the insertion point. The ``TEST: <name>`` banner is printed
    first; it is the line each test's FileCheck label anchors on.

    Returns ``f`` unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testAffineStoreOp
@constructAndPrintInModule
def testAffineStoreOp():
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type_out = MemRefType.get([12, 12], f32)

    # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
    @func.FuncOp.from_py_func(index_type)
    def affine_store_test(arg0):
        # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
        mem = memref.AllocOp(memref_type_out, [], []).result

        d0 = AffineDimExpr.get(0)
        s0 = AffineSymbolExpr.get(0)
        # One dim and one symbol: (d0)[s0] -> (s0 * 3, d0 + s0 + 1).
        # Both map operands are bound to arg0 below.
        index_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])

        # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
        a1 = arith.ConstantOp(f32, 2.1)

        # CHECK: affine.store %[[A1]], %[[O_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
        affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=index_map)

        return mem


# CHECK-LABEL: TEST: testAffineDelinearizeInfer
@constructAndPrintInModule
def testAffineDelinearizeInfer():
    # CHECK: %[[C1:.*]] = arith.constant 1 : index
    c1 = arith.ConstantOp(T.index(), 1)
    # Reference the captured C1 (not re-bind it) so the operand is really matched.
    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1]] into (2, 3) : index, index
    affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [2, 3])


# CHECK-LABEL: TEST: testAffineLoadOp
@constructAndPrintInModule
def testAffineLoadOp():
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type_in = MemRefType.get([10, 10], f32)

    # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
    @func.FuncOp.from_py_func(memref_type_in, index_type)
    def affine_load_test(I, arg0):
        d0 = AffineDimExpr.get(0)
        s0 = AffineSymbolExpr.get(0)
        # Same (d0)[s0] -> (s0 * 3, d0 + s0 + 1) map as the store test.
        index_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])

        # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
        a1 = affine.AffineLoadOp(f32, I, indices=[arg0, arg0], map=index_map)

        return a1


# CHECK-LABEL: TEST: testAffineForOp
@constructAndPrintInModule
def testAffineForOp():
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type = MemRefType.get([1024], f32)

    # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
    # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
    # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
    @func.FuncOp.from_py_func(memref_type)
    def affine_for_op_test(buffer):
        # CHECK: %[[C1:.*]] = arith.constant 1 : index
        c1 = arith.ConstantOp(index_type, 1)
        # CHECK: %[[C2:.*]] = arith.constant 2 : index
        c2 = arith.ConstantOp(index_type, 2)
        # CHECK: %[[C3:.*]] = arith.constant 3 : index
        c3 = arith.ConstantOp(index_type, 3)
        # CHECK: %[[C9:.*]] = arith.constant 9 : index
        c9 = arith.ConstantOp(index_type, 9)

        ac0 = AffineConstantExpr.get(0)
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        # Multi-result bounds: the loop starts at max(0, d0 + s0) and ends at
        # min(d0 - 2, 32 * d1), as the printed `max`/`min` forms show.
        lb = AffineMap.get(1, 1, [ac0, d0 + s0])
        ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])
        # CHECK: %[[AC0:.*]] = arith.constant 0.000000e+00 : f32
        sum_0 = arith.ConstantOp(f32, 0.0)

        # CHECK: %0 = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[AC0]]) -> (f32) {
        loop = affine.AffineForOp(
            lb,
            ub,
            2,
            iter_args=[sum_0],
            lower_bound_operands=[c2, c3],
            upper_bound_operands=[c9, c1],
        )

        with InsertionPoint(loop.body):
            # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
            tmp = memref.LoadOp(buffer, [loop.induction_variable])
            sum_next = arith.AddFOp(loop.inner_iter_args[0], tmp)
            affine.AffineYieldOp([sum_next])


def _expect_value_error(build, expected_message):
    """Call ``build()`` and require it to raise ``ValueError(expected_message)``.

    Raises ``AssertionError`` in two cases:
    - ``build`` returns normally;
    - the raised ``ValueError`` has a different first argument.

    An explicit ``raise`` is used instead of ``assert`` so the check still
    runs under ``python -O``.
    """
    try:
        build()
    except ValueError as e:
        if e.args[0] != expected_message:
            raise AssertionError(
                f"unexpected ValueError message: {e.args[0]!r}"
            ) from e
    else:
        raise AssertionError(f"expected ValueError({expected_message!r})")


# CHECK-LABEL: TEST: testAffineForOpErrors
@constructAndPrintInModule
def testAffineForOpErrors():
    c1 = arith.ConstantOp(T.index(), 1)
    c2 = arith.ConstantOp(T.index(), 2)
    c3 = arith.ConstantOp(T.index(), 3)

    # A concrete Value as lower bound cannot also take explicit operands.
    _expect_value_error(
        lambda: affine.AffineForOp(
            c1,
            c2,
            1,
            lower_bound_operands=[c3],
            upper_bound_operands=[],
        ),
        "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported.",
    )

    # A constant map has no inputs, so two operands is a count mismatch.
    _expect_value_error(
        lambda: affine.AffineForOp(
            AffineMap.get_constant(1),
            c2,
            1,
            lower_bound_operands=[c3, c3],
            upper_bound_operands=[],
        ),
        "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2.",
    )

    # An op with two results cannot serve as a single concrete bound.
    two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [1, 1])
    _expect_value_error(
        lambda: affine.AffineForOp(
            two_indices,
            c2,
            1,
            lower_bound_operands=[],
            upper_bound_operands=[],
        ),
        "Only a single concrete value is supported for lower bound.",
    )

    # A float is neither an int, a Value-like, nor an AffineMap.
    _expect_value_error(
        lambda: affine.AffineForOp(
            1.0,
            c2,
            1,
            lower_bound_operands=[],
            upper_bound_operands=[],
        ),
        "lower bound must be int | ResultValueT | AffineMap.",
    )


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Build affine.for loops through the ``affine.for_`` generator helper.

    Covers every combination of Value and int bounds. Also covers a loop that
    carries a single iter_arg, which is yielded back each iteration.
    """
    memref_t = T.memref(10, T.index())
    # Local alias (rather than shadowing the builtin ``range``) so the loops
    # below still read like ordinary Python iteration.
    affine_range = affine.for_

    # CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>

    # CHECK-LABEL: func.func @range_loop_1(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_1(lb, ub, memref_v):
        # Value lower bound, Value upper bound.
        for i in affine_range(lb, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_2(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to 10 {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_2(lb, ub, memref_v):
        # Value lower bound, int upper bound.
        for i in affine_range(lb, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_3(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_3(lb, ub, memref_v):
        # int lower bound, Value upper bound.
        for i in affine_range(0, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_4(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_4(lb, ub, memref_v):
        # int lower bound, int upper bound.
        for i in affine_range(0, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL: func.func @range_loop_8(
    # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
    # CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK: memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK: affine.yield %[[VAL_5]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_8(lb, ub, memref_v):
        # With one iter_arg, each iteration unpacks to (induction var, carried value).
        for i, it in affine_range(0, 10, iter_args=[memref_v]):
            add = arith.addi(i, i)
            memref.store(add, it, [i])
            affine.yield_([it])


# CHECK-LABEL: TEST: testAffineIfWithoutElse
@constructAndPrintInModule
def testAffineIfWithoutElse():
    index = IndexType.get()
    i32 = IntegerType.get_signless(32)
    d0 = AffineDimExpr.get(0)

    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
    cond = IntegerSet.get(1, 0, [d0 - 5], [False])

    # CHECK-LABEL: func.func @simple_affine_if(
    # CHECK-SAME: %[[VAL_0:.*]]: index) {
    # CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
    # CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
    # CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index)
    def simple_affine_if(operand):
        # `operand` is the single dim operand of the integer set.
        if_op = affine.AffineIfOp(cond, cond_operands=[operand])
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            arith.AddIOp(one, one)
            affine.AffineYieldOp([])
        return


# CHECK-LABEL: TEST: testAffineIfWithElse
@constructAndPrintInModule
def testAffineIfWithElse():
    index = IndexType.get()
    i32 = IntegerType.get_signless(32)
    d0 = AffineDimExpr.get(0)

    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
    cond = IntegerSet.get(1, 0, [d0 - 5], [False])

    # CHECK-LABEL: func.func @simple_affine_if_else(
    # CHECK-SAME: %[[VAL_0:.*]]: index) {
    # CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
    # CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
    # CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
    # CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
    # CHECK: } else {
    # CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
    # CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
    # CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
    # CHECK: }
    # CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
    # CHECK: return
    # CHECK: }

    @func.FuncOp.from_py_func(index)
    def simple_affine_if_else(operand):
        # Two i32 results; both regions must yield a value for each.
        if_op = affine.AffineIfOp(
            cond, [i32, i32], cond_operands=[operand], has_else=True
        )
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            affine.AffineYieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            affine.AffineYieldOp([x_false, y_false])
        arith.AddIOp(if_op.results[0], if_op.results[1])
        return