xref: /llvm-project/mlir/test/python/dialects/scf.py (revision ad89e617c703239518187912540b8ea811dc2eda)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import arith
5from mlir.dialects import func
6from mlir.dialects import memref
7from mlir.dialects import scf
8from mlir.passmanager import PassManager
9
10
def constructAndPrintInModule(f):
    """Run ``f`` inside a fresh MLIR module and print the resulting IR.

    Intended to be used as a decorator: ``f`` is executed immediately, at
    decoration time, with a new ``Context`` and an unknown ``Location``
    active and the insertion point set to the body of a newly created
    module. A ``TEST: <name>`` banner is printed before ``f`` runs and the
    module is printed after it returns.

    Returns ``f`` unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
            print(mod)
    return f
19
20
21# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """Build an ``scf.for`` carrying two iter args and yield them back."""
    idx = IndexType.get()

    @func.FuncOp.from_py_func(idx, idx, idx)
    def simple_loop(lb, ub, step):
        # Both loop-carried values are initialised with the lower bound.
        for_op = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(for_op.body):
            # Forward the loop-carried block arguments unchanged.
            scf.YieldOp(for_op.inner_iter_args)
32
33
34# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
35# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
36# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
37# CHECK: scf.yield %[[I1]], %[[I2]]
38
39
40# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """Build an ``scf.for`` whose body yields its own induction variable."""
    idx = IndexType.get()

    @func.FuncOp.from_py_func(idx, idx, idx)
    def induction_var(lb, ub, step):
        for_op = scf.ForOp(lb, ub, step, [lb])
        with InsertionPoint(for_op.body):
            # The single loop-carried value is replaced by the induction
            # variable on every iteration.
            scf.YieldOp([for_op.induction_variable])
51
52
53# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
54# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
55# CHECK: scf.yield %[[IV]]
56
57
58# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``scf.for_`` generator-style loop builder.

    The nested functions cover bounds and steps given as function arguments,
    as Python integers, or as a mix of both, the two- and one-argument call
    forms, and loops that carry one or two values through ``iter_args``.
    """
    idx = IndexType.get()
    buf_type = MemRefType.get([10], idx)
    # Local alias so the loops below read like ordinary Python ``range``
    # loops; it shadows the builtin only inside this test function.
    range = scf.for_

    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK:      %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK:      memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_1(lb, ub, step, memref_v):
        for i in range(lb, ub, step):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_2(lb, ub, step, memref_v):
        for i in range(lb, 10, 1):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_3(lb, ub, step, memref_v):
        for i in range(0, ub, 1):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_4(lb, ub, step, memref_v):
        for i in range(0, 10, step):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_5(lb, ub, step, memref_v):
        for i in range(0, 10, 1):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_6(lb, ub, step, memref_v):
        for i in range(0, 10):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def range_loop_7(lb, ub, step, memref_v):
        for i in range(10):
            memref.store(arith.addi(i, i), memref_v, [i])
            scf.yield_([])

    # CHECK:  func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_5:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK:    %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
    # CHECK:      scf.yield %[[VAL_9]] : index
    # CHECK:    }
    # CHECK:    memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def loop_yield_1(lb, ub, step, memref_v):
        total = arith.ConstantOp.create_index(0)
        c0 = arith.ConstantOp.create_index(0)
        # Each iteration unpacks the induction variable, the loop-carried
        # value and the loop result; ``total`` is rebound to the result so
        # it can be stored after the loop.
        for i, acc, total in scf.for_(0, 100, 1, [total]):
            acc = arith.addi(acc, i)
            scf.yield_([acc])
        memref.store(total, memref_v, [c0])

    # CHECK:  func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK:    %[[c0:.*]] = arith.constant 0 : index
    # CHECK:    %[[c2:.*]] = arith.constant 2 : index
    # CHECK:    %[[REF1:.*]] = arith.constant 0 : index
    # CHECK:    %[[REF2:.*]] = arith.constant 1 : index
    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK:    %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
    # CHECK:      %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
    # CHECK:      scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
    # CHECK:    }
    # CHECK:    return
    # CHECK:  }
    @func.FuncOp.from_py_func(idx, idx, idx, buf_type)
    def loop_yield_2(lb, ub, step, memref_v):
        total1 = arith.ConstantOp.create_index(0)
        total2 = arith.ConstantOp.create_index(2)
        c0 = arith.ConstantOp.create_index(0)
        c1 = arith.ConstantOp.create_index(1)
        # With two iter args the loop-carried values and the results are
        # each unpacked from a two-element sequence.
        for i, [acc1, acc2], [total1, total2] in scf.for_(0, 100, 1, [total1, total2]):
            acc1 = arith.addi(acc1, i)
            acc2 = arith.addi(acc2, i)
            scf.yield_([acc1, acc2])
        memref.store(total1, memref_v, [c0])
        memref.store(total2, memref_v, [c1])
228
229
@constructAndPrintInModule
def testOpsAsArguments():
    """Build an ``scf.ForOp`` whose operands are given as op objects."""
    idx = IndexType.get()
    # Body-less private function returning two index values; the op that
    # calls it is handed to the loop as its iter args below.
    callee = func.FuncOp("callee", ([], [idx, idx]), visibility="private")
    caller = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(caller.add_entry_block()):
        lower = arith.ConstantOp.create_index(0)
        upper = arith.ConstantOp.create_index(42)
        stride = arith.ConstantOp.create_index(2)
        call_op = func.CallOp(callee, [])
        for_op = scf.ForOp(lower, upper, stride, call_op)
        with InsertionPoint(for_op.body):
            scf.YieldOp(for_op.inner_iter_args)
        func.ReturnOp([])
244
245
246# CHECK-LABEL: TEST: testOpsAsArguments
247# CHECK: func private @callee() -> (index, index)
248# CHECK: func @ops_as_arguments() {
249# CHECK:   %[[LB:.*]] = arith.constant 0
250# CHECK:   %[[UB:.*]] = arith.constant 42
251# CHECK:   %[[STEP:.*]] = arith.constant 2
252# CHECK:   %[[ARGS:.*]]:2 = call @callee()
# CHECK:   scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
254# CHECK:   iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
255# CHECK:     scf.yield %{{.*}}, %{{.*}}
256# CHECK:   return
257
258
@constructAndPrintInModule
def testIfWithoutElse():
    """Build an ``scf.if`` that has only a then-block."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        branch = scf.IfOp(cond)
        with InsertionPoint(branch.then_block):
            one = arith.ConstantOp(i32, 1)
            # Created only so that it appears in the printed IR.
            arith.AddIOp(one, one)
            scf.YieldOp([])
272
273
# CHECK-LABEL: TEST: testIfWithoutElse
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
276# CHECK:   %[[ONE:.*]] = arith.constant 1
277# CHECK:   %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
278# CHECK: return
279
280
@constructAndPrintInModule
def testNestedIf():
    """Build an ``scf.if`` nested inside the then-block of another."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1, i1)
    def nested_if(b, c):
        outer = scf.IfOp(b)
        with InsertionPoint(outer.then_block) as ip:
            # The inner op is created with the insertion point passed
            # explicitly through ``ip``.
            inner = scf.IfOp(c, ip=ip)
            with InsertionPoint(inner.then_block):
                one = arith.ConstantOp(i32, 1)
                arith.AddIOp(one, one)
                scf.YieldOp([])
            scf.YieldOp([])
297
298
# CHECK-LABEL: TEST: testNestedIf
# CHECK: func @nested_if(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK:   scf.if %[[ARG1]]
302# CHECK:     %[[ONE:.*]] = arith.constant 1
303# CHECK:     %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
304# CHECK: return
305
306
@constructAndPrintInModule
def testIfWithElse():
    """Build a two-result ``scf.if`` with both then- and else-blocks."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        branch = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(branch.then_block):
            then_x = arith.ConstantOp(i32, 0)
            then_y = arith.ConstantOp(i32, 1)
            scf.YieldOp([then_x, then_y])
        with InsertionPoint(branch.else_block):
            else_x = arith.ConstantOp(i32, 2)
            else_y = arith.ConstantOp(i32, 3)
            scf.YieldOp([else_x, else_y])
        # Consume both results of the conditional after it.
        first = branch.results[0]
        second = branch.results[1]
        arith.AddIOp(first, second)
325
326
# CHECK-LABEL: TEST: testIfWithElse
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
329# CHECK:   %[[ZERO:.*]] = arith.constant 0
330# CHECK:   %[[ONE:.*]] = arith.constant 1
331# CHECK:   scf.yield %[[ZERO]], %[[ONE]]
332# CHECK: } else {
333# CHECK:   %[[TWO:.*]] = arith.constant 2
334# CHECK:   %[[THREE:.*]] = arith.constant 3
335# CHECK:   scf.yield %[[TWO]], %[[THREE]]
336# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
337# CHECK: return
338