# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.passmanager import PassManager


def constructAndPrintInModule(f):
    """Run ``f`` inside a fresh module and print the resulting IR.

    Used as a decorator, so it runs at definition time. It prints a
    ``TEST: <name>`` banner (the anchor for the ``CHECK-LABEL`` directives in
    this file). It then creates an empty ``Module`` under a new ``Context``
    with an unknown ``Location`` and calls ``f()`` with the module body as the
    insertion point. Finally it prints the module so FileCheck can match it.
    ``f`` itself is returned unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """Build an ``scf.for`` whose body yields its two iter_args unchanged."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def simple_loop(lb, ub, step):
        loop = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        return


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]


# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """Build an ``scf.for`` whose body yields the induction variable."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def induction_var(lb, ub, step):
        loop = scf.ForOp(lb, ub, step, [lb])
        with InsertionPoint(loop.body):
            scf.YieldOp([loop.induction_variable])
        return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``scf.for_`` generator sugar with Python ``for`` loops.

    Covers every mix of Value/int bounds and the 1-, 2- and 3-argument
    ``range``-style call shapes. It also covers loops that carry one or two
    iter_args whose results are used after the loop.
    """
    index_type = IndexType.get()
    memref_t = MemRefType.get([10], index_type)
    # Deliberately shadow the builtin so the loops below read like plain
    # Python ``for i in range(...)``; this spelling is what the test exercises.
    range = scf.for_

    # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_1(lb, ub, step, memref_v):
        for i in range(lb, ub, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_2(lb, ub, step, memref_v):
        for i in range(lb, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_3(lb, ub, step, memref_v):
        for i in range(0, ub, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_4(lb, ub, step, memref_v):
        for i in range(0, 10, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_5(lb, ub, step, memref_v):
        for i in range(0, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_6(lb, ub, step, memref_v):
        for i in range(0, 10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_7(lb, ub, step, memref_v):
        for i in range(10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]] : index
    # CHECK: }
    # CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_1(lb, ub, step, memref_v):
        init_sum = arith.ConstantOp.create_index(0)
        c0 = arith.ConstantOp.create_index(0)
        # With a single iter_arg, scf.for_ yields (iv, iter_arg, result);
        # ``final_sum`` is the loop's result and is used after the loop.
        for i, loc_sum, final_sum in scf.for_(0, 100, 1, [init_sum]):
            loc_sum = arith.addi(loc_sum, i)
            scf.yield_([loc_sum])
        memref.store(final_sum, memref_v, [c0])

    # CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[c0:.*]] = arith.constant 0 : index
    # CHECK: %[[c2:.*]] = arith.constant 2 : index
    # CHECK: %[[REF1:.*]] = arith.constant 0 : index
    # CHECK: %[[REF2:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[RES:.*]]:2 = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
    # CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
    # CHECK: }
    # CHECK: memref.store %[[RES]]#0, %[[VAL_3]]{{\[}}%[[REF1]]] : memref<10xindex>
    # CHECK: memref.store %[[RES]]#1, %[[VAL_3]]{{\[}}%[[REF2]]] : memref<10xindex>
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_2(lb, ub, step, memref_v):
        sum1 = arith.ConstantOp.create_index(0)
        sum2 = arith.ConstantOp.create_index(2)
        c0 = arith.ConstantOp.create_index(0)
        c1 = arith.ConstantOp.create_index(1)
        # With several iter_args, scf.for_ yields (iv, iter_args, results);
        # both tuples are unpacked in place.
        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
            loc_sum1 = arith.addi(loc_sum1, i)
            loc_sum2 = arith.addi(loc_sum2, i)
            scf.yield_([loc_sum1, loc_sum2])
        memref.store(sum1, memref_v, [c0])
        memref.store(sum2, memref_v, [c1])


# CHECK-LABEL: TEST: testOpsAsArguments
@constructAndPrintInModule
def testOpsAsArguments():
    """Pass ops, not values, as ``scf.ForOp`` bounds and iter_args.

    ``lb``/``ub``/``step`` are ``arith.ConstantOp`` ops and ``iter_args`` is a
    two-result ``func.CallOp``. The checks below verify that the loop consumes
    their results (``%[[ARGS]]#0`` / ``#1``).
    """
    index_type = IndexType.get()
    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
    f = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(f.add_entry_block()):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(42)
        step = arith.ConstantOp.create_index(2)
        iter_args = func.CallOp(callee, [])
        loop = scf.ForOp(lb, ub, step, iter_args)
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        func.ReturnOp([])


# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK: %[[LB:.*]] = arith.constant 0
# CHECK: %[[UB:.*]] = arith.constant 42
# CHECK: %[[STEP:.*]] = arith.constant 2
# CHECK: %[[ARGS:.*]]:2 = call @callee()
# CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return


# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
    """Build an ``scf.if`` that has only a then-region."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        if_op = scf.IfOp(cond)
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            arith.AddIOp(one, one)
            scf.YieldOp([])
        return


# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return


# CHECK-LABEL: TEST: testNestedIf
@constructAndPrintInModule
def testNestedIf():
    """Build an ``scf.if`` nested inside the then-region of another."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1, i1)
    def nested_if(b, c):
        outer_if = scf.IfOp(b)
        with InsertionPoint(outer_if.then_block) as ip:
            inner_if = scf.IfOp(c, ip=ip)
            with InsertionPoint(inner_if.then_block):
                one = arith.ConstantOp(i32, 1)
                arith.AddIOp(one, one)
                scf.YieldOp([])
            # Terminator of the outer then-block, emitted after the inner if.
            scf.YieldOp([])
        return


# CHECK: func @nested_if(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: scf.if %[[ARG1]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return


# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
    """Build a two-result ``scf.if``/``else`` and consume both results."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            scf.YieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            scf.YieldOp([x_false, y_false])
        arith.AddIOp(if_op.results[0], if_op.results[1])
        return


# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return