# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.passmanager import PassManager


def constructAndPrintInModule(f):
    """Run the test body `f` inside a fresh module and print the result.

    Prints a ``TEST: <name>`` header (the anchor for the ``CHECK-LABEL``
    lines below). Then it calls ``f`` under a new ``Context``, an unknown
    ``Location`` and an insertion point in the body of a freshly created
    module. Finally it prints the module so FileCheck can match the
    generated IR.

    Used as a decorator: the test runs at definition time and ``f`` is
    returned unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """scf.for with two iter_args, both seeded with `lb` and yielded back."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def simple_loop(lb, ub, step):
        loop = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(loop.body):
            # Yield the loop-carried values unchanged.
            scf.YieldOp(loop.inner_iter_args)
        return


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]


# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """scf.for that yields its own induction variable as the iter_arg."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def induction_var(lb, ub, step):
        loop = scf.ForOp(lb, ub, step, [lb])
        with InsertionPoint(loop.body):
            scf.YieldOp([loop.induction_variable])
        return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Build loops with the `scf.for_` / `scf.yield_` Python-level sugar.

    Covers the range-style call forms (SSA or literal lb/ub/step, plus the
    one- and two-argument forms) and loops carrying one or two iter_args.
    As the CHECK lines show, integer literals passed to `scf.for_` become
    `arith.constant ... : index` ops emitted just before the `scf.for`.
    """
    index_type = IndexType.get()
    memref_t = MemRefType.get([10], index_type)
    # Deliberately shadow the builtin so the loops below read like plain
    # Python `for ... in range(...)` loops.
    range = scf.for_

    # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_1(lb, ub, step, memref_v):
        for i in range(lb, ub, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_2(lb, ub, step, memref_v):
        for i in range(lb, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_3(lb, ub, step, memref_v):
        for i in range(0, ub, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_4(lb, ub, step, memref_v):
        for i in range(0, 10, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_5(lb, ub, step, memref_v):
        for i in range(0, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_6(lb, ub, step, memref_v):
        for i in range(0, 10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_7(lb, ub, step, memref_v):
        for i in range(10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]] : index
    # CHECK: }
    # CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_1(lb, ub, step, memref_v):
        # Creation order matters: the init value must be emitted before `c0`.
        init_sum = arith.ConstantOp.create_index(0)
        c0 = arith.ConstantOp.create_index(0)
        # With a single iter_arg the loop unpacks as
        # (induction var, loop-carried value, loop result).
        for i, loc_sum, final_sum in scf.for_(0, 100, 1, [init_sum]):
            loc_sum = arith.addi(loc_sum, i)
            scf.yield_([loc_sum])
        memref.store(final_sum, memref_v, [c0])

    # CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[c0:.*]] = arith.constant 0 : index
    # CHECK: %[[c2:.*]] = arith.constant 2 : index
    # CHECK: %[[REF1:.*]] = arith.constant 0 : index
    # CHECK: %[[REF2:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[RES:.*]]:2 = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
    # CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
    # CHECK: }
    # CHECK: memref.store %[[RES]]#0, %[[VAL_3]]{{\[}}%[[REF1]]] : memref<10xindex>
    # CHECK: memref.store %[[RES]]#1, %[[VAL_3]]{{\[}}%[[REF2]]] : memref<10xindex>
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_2(lb, ub, step, memref_v):
        init_sum1 = arith.ConstantOp.create_index(0)
        init_sum2 = arith.ConstantOp.create_index(2)
        c0 = arith.ConstantOp.create_index(0)
        c1 = arith.ConstantOp.create_index(1)
        # With several iter_args, the loop-carried values and the loop results
        # are each unpacked as a pair.
        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(
            0, 100, 1, [init_sum1, init_sum2]
        ):
            loc_sum1 = arith.addi(loc_sum1, i)
            loc_sum2 = arith.addi(loc_sum2, i)
            scf.yield_([loc_sum1, loc_sum2])
        memref.store(sum1, memref_v, [c0])
        memref.store(sum2, memref_v, [c1])


@constructAndPrintInModule
def testOpsAsArguments():
    """Pass op objects directly (constants and a two-result call) to scf.ForOp."""
    index_type = IndexType.get()
    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
    f = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(f.add_entry_block()):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(42)
        step = arith.ConstantOp.create_index(2)
        # The call op itself (two results) supplies both iter_args.
        iter_args = func.CallOp(callee, [])
        loop = scf.ForOp(lb, ub, step, iter_args)
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        func.ReturnOp([])


# CHECK-LABEL: TEST: testOpsAsArguments
# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK: %[[LB:.*]] = arith.constant 0
# CHECK: %[[UB:.*]] = arith.constant 42
# CHECK: %[[STEP:.*]] = arith.constant 2
# CHECK: %[[ARGS:.*]]:2 = call @callee()
# CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return


# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
    """scf.if with only a then-region that adds a constant to itself."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        if_op = scf.IfOp(cond)
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            arith.AddIOp(one, one)
            scf.YieldOp([])
        return


# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return


# CHECK-LABEL: TEST: testNestedIf
@constructAndPrintInModule
def testNestedIf():
    """scf.if nested inside the then-region of another scf.if."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1, i1)
    def nested_if(b, c):
        outer_if = scf.IfOp(b)
        with InsertionPoint(outer_if.then_block) as ip:
            # The insertion point is also passed explicitly through `ip=`.
            inner_if = scf.IfOp(c, ip=ip)
            with InsertionPoint(inner_if.then_block):
                one = arith.ConstantOp(i32, 1)
                arith.AddIOp(one, one)
                scf.YieldOp([])
            # Terminator for the outer then-region, after the inner scf.if.
            scf.YieldOp([])
        return


# CHECK: func @nested_if(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: scf.if %[[ARG1]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return


# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
    """scf.if whose then/else regions each yield two i32 values, then added."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            scf.YieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            scf.YieldOp([x_false, y_false])
        arith.AddIOp(if_op.results[0], if_op.results[1])
        return


# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return