xref: /llvm-project/mlir/test/python/dialects/scf.py (revision ad89e617c703239518187912540b8ea811dc2eda)
18c1b785cSAlex Zinenko# RUN: %PYTHON %s | FileCheck %s
28c1b785cSAlex Zinenko
38c1b785cSAlex Zinenkofrom mlir.ir import *
4a54f4eaeSMogballfrom mlir.dialects import arith
523aa5a74SRiver Riddlefrom mlir.dialects import func
6e9453f3cSMaksim Leventalfrom mlir.dialects import memref
78c1b785cSAlex Zinenkofrom mlir.dialects import scf
8e9453f3cSMaksim Leventalfrom mlir.passmanager import PassManager
98c1b785cSAlex Zinenko
108c1b785cSAlex Zinenko
def constructAndPrintInModule(f):
    """Test decorator: build ``f``'s IR inside a fresh module and print it.

    Prints a ``TEST: <name>`` header (the text the ``CHECK-LABEL`` lines key
    on), then, under a new ``Context`` with ``Location.unknown()`` active,
    creates an empty module, calls ``f`` with the insertion point at the
    module body, and prints the resulting module. ``f`` itself is returned
    unchanged so the decorated name stays bound to the original function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context(), Location.unknown():
        mod = Module.create()
        body_ip = InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    return f
198c1b785cSAlex Zinenko
208c1b785cSAlex Zinenko
# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """Build an ``scf.for`` with two iter_args that are yielded back unchanged."""
    idx = IndexType.get()

    @func.FuncOp.from_py_func(idx, idx, idx)
    def simple_loop(lb, ub, step):
        # Both loop-carried values are seeded with the lower bound.
        for_op = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(for_op.body):
            # Identity body: forward the region's iter_args to the terminator.
            scf.YieldOp(for_op.inner_iter_args)
        return


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]
388c1b785cSAlex Zinenko
398c1b785cSAlex Zinenko
# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """Build an ``scf.for`` whose body yields its own induction variable."""
    idx = IndexType.get()

    @func.FuncOp.from_py_func(idx, idx, idx)
    def induction_var(lb, ub, step):
        # Single iter_arg seeded with the lower bound; the body yields the IV.
        for_op = scf.ForOp(lb, ub, step, [lb])
        body_ip = InsertionPoint(for_op.body)
        with body_ip:
            scf.YieldOp([for_op.induction_variable])
        return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]
56b164f23cSAlex Zinenko
57b164f23cSAlex Zinenko
# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``scf.for_`` generator sugar.

    ``scf.for_`` is bound to the local name ``range`` so the ``range_loop_*``
    functions read like ordinary Python ``range`` loops. Each variant checks
    which bounds are forwarded from function arguments and which literal
    bounds are materialized as ``arith.constant`` ops. The ``loop_yield_*``
    functions use the iter_args form, where the loop results are stored back
    into the memref after the loop.
    """
    index_type = IndexType.get()
    memref_t = MemRefType.get([10], index_type)
    # Deliberately shadows the builtin so the sugar reads like Python's range.
    range = scf.for_

    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK:      %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK:      memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_1(lb, ub, step, memref_v):
        for i in range(lb, ub, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_2(lb, ub, step, memref_v):
        for i in range(lb, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_3(lb, ub, step, memref_v):
        for i in range(0, ub, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_4(lb, ub, step, memref_v):
        for i in range(0, 10, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_5(lb, ub, step, memref_v):
        for i in range(0, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_6(lb, ub, step, memref_v):
        for i in range(0, 10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_7(lb, ub, step, memref_v):
        for i in range(10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK:    %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
    # CHECK:      scf.yield %[[VAL_9]] : index
    # CHECK:    }
    # CHECK:    memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_1(lb, ub, step, memref_v):
        # Initial value of the single loop-carried accumulator (avoids
        # shadowing the builtin `sum`).
        total = arith.ConstantOp.create_index(0)
        c0 = arith.ConstantOp.create_index(0)
        # The sugar hands back (iv, in-loop carried value, loop result);
        # rebinding `total` makes it refer to the scf.for result afterwards.
        for i, loc_sum, total in scf.for_(0, 100, 1, [total]):
            loc_sum = arith.addi(loc_sum, i)
            scf.yield_([loc_sum])
        memref.store(total, memref_v, [c0])

    # CHECK:  func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[c0:.*]] = arith.constant 0 : index
    # CHECK:    %[[c2:.*]] = arith.constant 2 : index
    # CHECK:    %[[REF1:.*]] = arith.constant 0 : index
    # CHECK:    %[[REF2:.*]] = arith.constant 1 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK:    %[[RES:.*]]:2 = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
    # CHECK:      %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
    # CHECK:      scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
    # CHECK:    }
    # CHECK:    memref.store %[[RES]]#0, %[[VAL_3]]{{\[}}%[[REF1]]] : memref<10xindex>
    # CHECK:    memref.store %[[RES]]#1, %[[VAL_3]]{{\[}}%[[REF2]]] : memref<10xindex>
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_2(lb, ub, step, memref_v):
        sum1 = arith.ConstantOp.create_index(0)
        sum2 = arith.ConstantOp.create_index(2)
        c0 = arith.ConstantOp.create_index(0)
        c1 = arith.ConstantOp.create_index(1)
        # With several iter_args the carried values and results unpack as lists;
        # afterwards sum1/sum2 are the two scf.for results.
        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
            loc_sum1 = arith.addi(loc_sum1, i)
            loc_sum2 = arith.addi(loc_sum2, i)
            scf.yield_([loc_sum1, loc_sum2])
        memref.store(sum1, memref_v, [c0])
        memref.store(sum2, memref_v, [c1])
2287f58ffd0SGuray Ozen
22927c6d55cSMaksim Levental
# CHECK-LABEL: TEST: testOpsAsArguments
@constructAndPrintInModule
def testOpsAsArguments():
    """Check that ops, not just values, are accepted as ``scf.ForOp`` operands.

    The bounds are passed as ``arith.ConstantOp`` ops and the iter_args as a
    two-result ``func.CallOp``, which becomes two loop-carried values
    (``%ARGS#0`` and ``%ARGS#1`` in the checks below).
    """
    index_type = IndexType.get()
    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
    f = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(f.add_entry_block()):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(42)
        step = arith.ConstantOp.create_index(2)
        iter_args = func.CallOp(callee, [])
        loop = scf.ForOp(lb, ub, step, iter_args)
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        func.ReturnOp([])


# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK:   %[[LB:.*]] = arith.constant 0
# CHECK:   %[[UB:.*]] = arith.constant 42
# CHECK:   %[[STEP:.*]] = arith.constant 2
# CHECK:   %[[ARGS:.*]]:2 = call @callee()
# CHECK:   scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK:   iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK:     scf.yield %{{.*}}, %{{.*}}
# CHECK:   return
257036088fdSchhzh123
258036088fdSchhzh123
# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
    """Build a result-less ``scf.if`` that has only a then-region."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        if_op = scf.IfOp(cond)
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            # Constructing the op inserts it; the Python handle is not needed.
            arith.AddIOp(one, one)
            scf.YieldOp([])
        return


# Use (not redefine) ARG0 so the scf.if condition is actually verified.
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK:   %[[ONE:.*]] = arith.constant 1
# CHECK:   %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
279036088fdSchhzh123
280036088fdSchhzh123
# CHECK-LABEL: TEST: testNestedIf
@constructAndPrintInModule
def testNestedIf():
    """Build an ``scf.if`` nested inside another ``scf.if``'s then-region.

    The inner op receives the outer then-block insertion point explicitly via
    the ``ip=`` keyword, in addition to it being the active context.
    """
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1, i1)
    def nested_if(b, c):
        outer_if = scf.IfOp(b)
        with InsertionPoint(outer_if.then_block) as ip:
            # Pass the insertion point explicitly to exercise the `ip=` kwarg.
            inner_if = scf.IfOp(c, ip=ip)
            with InsertionPoint(inner_if.then_block):
                one = arith.ConstantOp(i32, 1)
                arith.AddIOp(one, one)
                scf.YieldOp([])
            # Terminates the outer then-region (the active insertion point).
            scf.YieldOp([])
        return


# Use (not redefine) ARG0/ARG1 so each scf.if condition is actually verified.
# CHECK: func @nested_if(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK:   scf.if %[[ARG1]]
# CHECK:     %[[ONE:.*]] = arith.constant 1
# CHECK:     %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
305*ad89e617SMatt Hofmann
306*ad89e617SMatt Hofmann
# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
    """Build a two-result ``scf.if`` with then/else regions and add its results."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            scf.YieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            scf.YieldOp([x_false, y_false])
        # Consume both results so the checks can verify %RET#0 / %RET#1.
        arith.AddIOp(if_op.results[0], if_op.results[1])
        return


# Use (not redefine) ARG0 so the scf.if condition is actually verified.
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK:   %[[ZERO:.*]] = arith.constant 0
# CHECK:   %[[ONE:.*]] = arith.constant 1
# CHECK:   scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK:   %[[TWO:.*]] = arith.constant 2
# CHECK:   %[[THREE:.*]] = arith.constant 3
# CHECK:   scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return
338