# xref: /llvm-project/mlir/test/python/dialects/affine.py (revision 31aa7f34e07c901773993dac0f33568307f96da6)
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import func
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine
import mlir.extras.types as T


def constructAndPrintInModule(f):
    """Run ``f`` immediately inside a fresh module and print the result.

    Intended as a decorator for the tests below. It prints a ``TEST: <name>``
    banner (the anchor for each ``CHECK-LABEL``), then calls ``f`` with a new
    ``Context``, an unknown ``Location`` and an insertion point in the body of
    a newly created ``Module``. Finally it prints that module and returns
    ``f`` unchanged.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context(), Location.unknown():
        test_module = Module.create()
        with InsertionPoint(test_module.body):
            f()
        print(test_module)
    return f


# CHECK-LABEL: TEST: testAffineStoreOp
@constructAndPrintInModule
def testAffineStoreOp():
    """Build an ``affine.store`` whose subscripts come from a 1-dim/1-symbol map.

    The single function argument is bound to both the dim and the symbol slot
    of ``(d0)[s0] -> (s0 * 3, d0 + s0 + 1)``. The allocated memref is returned.
    """
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type_out = MemRefType.get([12, 12], f32)

    # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
    @func.FuncOp.from_py_func(index_type)
    def affine_store_test(arg0):
        # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
        mem = memref.AllocOp(memref_type_out, [], []).result

        d0 = AffineDimExpr.get(0)
        s0 = AffineSymbolExpr.get(0)
        store_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])

        # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
        a1 = arith.ConstantOp(f32, 2.1)

        # Refer to the allocation through its captured name instead of
        # hard-coding the name the printer happens to pick for it.
        # CHECK: affine.store %[[A1]], %[[O_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
        affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=store_map)

        # CHECK: return %[[O_VAR]] : memref<12x12xf32>
        return mem


# CHECK-LABEL: TEST: testAffineDelinearizeInfer
@constructAndPrintInModule
def testAffineDelinearizeInfer():
    """Build an ``affine.delinearize_index`` of a constant into basis (2, 3)."""
    # CHECK: %[[C1:.*]] = arith.constant 1 : index
    c1 = arith.ConstantOp(T.index(), 1)
    # Reuse the captured C1 (rather than redefining it) so the check proves the
    # delinearized operand is exactly that constant.
    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1]] into (2, 3) : index, index
    affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [2, 3])


# CHECK-LABEL: TEST: testAffineLoadOp
@constructAndPrintInModule
def testAffineLoadOp():
    """Build an ``affine.load`` through a 1-dim/1-symbol map and return it.

    The index argument fills both the dim and the symbol slot of
    ``(d0)[s0] -> (s0 * 3, d0 + s0 + 1)``.
    """
    elem_ty = F32Type.get()
    idx_ty = IndexType.get()
    buffer_ty = MemRefType.get([10, 10], elem_ty)

    # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
    @func.FuncOp.from_py_func(buffer_ty, idx_ty)
    def affine_load_test(I, arg0):
        dim = AffineDimExpr.get(0)
        sym = AffineSymbolExpr.get(0)
        load_map = AffineMap.get(1, 1, [sym * 3, dim + sym + 1])

        # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
        return affine.AffineLoadOp(elem_ty, I, indices=[arg0, arg0], map=load_map)


# CHECK-LABEL: TEST: testAffineForOp
@constructAndPrintInModule
def testAffineForOp():
    """Build an ``affine.for`` with multi-result bound maps and a loop-carried f32.

    The lower bound is ``max`` over ``(d0)[s0] -> (0, d0 + s0)``. The upper
    bound is ``min`` over ``(d0, d1) -> (d0 - 2, d1 * 32)``. The loop steps by
    2 and adds each loaded buffer element to its single ``iter_args`` value.
    """
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type = MemRefType.get([1024], f32)

    # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
    # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
    # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
    @func.FuncOp.from_py_func(memref_type)
    def affine_for_op_test(buffer):
        # CHECK: %[[C1:.*]] = arith.constant 1 : index
        c1 = arith.ConstantOp(index_type, 1)
        # CHECK: %[[C2:.*]] = arith.constant 2 : index
        c2 = arith.ConstantOp(index_type, 2)
        # CHECK: %[[C3:.*]] = arith.constant 3 : index
        c3 = arith.ConstantOp(index_type, 3)
        # CHECK: %[[C9:.*]] = arith.constant 9 : index
        c9 = arith.ConstantOp(index_type, 9)

        ac0 = AffineConstantExpr.get(0)
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        lb = AffineMap.get(1, 1, [ac0, d0 + s0])
        ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])

        # The loop's initial iter_args value is this f32 zero constant, not the
        # affine constant expression ``ac0`` used inside the lower-bound map.
        # CHECK: %[[SUM_INIT:.*]] = arith.constant 0.000000e+00 : f32
        sum_init = arith.ConstantOp(f32, 0.0)

        # The loop result is matched with a wildcard rather than a fixed SSA number.
        # CHECK: %{{.*}} = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[SUM_INIT]]) -> (f32) {
        for_op = affine.AffineForOp(
            lb,
            ub,
            2,
            iter_args=[sum_init],
            lower_bound_operands=[c2, c3],
            upper_bound_operands=[c9, c1],
        )

        with InsertionPoint(for_op.body):
            # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
            tmp = memref.LoadOp(buffer, [for_op.induction_variable])
            # CHECK: %[[SUM_NEXT:.*]] = arith.addf %[[SUM0]], %[[TMP]] : f32
            sum_next = arith.AddFOp(for_op.inner_iter_args[0], tmp)
            # CHECK: affine.yield %[[SUM_NEXT]] : f32
            affine.AffineYieldOp([sum_next])


# CHECK-LABEL: TEST: testAffineForOpErrors
@constructAndPrintInModule
def testAffineForOpErrors():
    """Check that ``AffineForOp`` rejects malformed lower bounds with ValueError.

    Every case must raise. A case that unexpectedly succeeds is reported as a
    failure instead of passing silently, which the previous bare
    ``try``/``except`` blocks allowed.
    """
    c1 = arith.ConstantOp(T.index(), 1)
    c2 = arith.ConstantOp(T.index(), 2)
    c3 = arith.ConstantOp(T.index(), 3)

    def expect_lower_bound_error(lower_bound, lower_bound_operands, message):
        # The cases differ only in the lower-bound arguments; the upper bound
        # (c2), the step (1) and the empty upper-bound operands are held fixed.
        try:
            affine.AffineForOp(
                lower_bound,
                c2,
                1,
                lower_bound_operands=lower_bound_operands,
                upper_bound_operands=[],
            )
        except ValueError as e:
            assert e.args[0] == message, e.args[0]
        else:
            raise AssertionError(f"AffineForOp did not raise ValueError: {message}")

    # A concrete lower bound must not be combined with lower-bound operands.
    expect_lower_bound_error(
        c1,
        [c3],
        "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported.",
    )

    # A constant map takes no operands, so passing two must be rejected.
    expect_lower_bound_error(
        AffineMap.get_constant(1),
        [c3, c3],
        "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2.",
    )

    # An op with two results is not a single concrete lower-bound value.
    two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [1, 1])
    expect_lower_bound_error(
        two_indices,
        [],
        "Only a single concrete value is supported for lower bound.",
    )

    # A float is not an accepted lower-bound kind.
    expect_lower_bound_error(
        1.0,
        [],
        "lower bound must be int | ResultValueT | AffineMap.",
    )


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``affine.for_`` loop-generator sugar.

    ``affine.for_`` is iterated with a plain Python ``for``. The loop variable
    is the induction variable, or an ``(induction variable, iter arg)`` pair
    when one ``iter_args`` value is passed (``range_loop_8``). Per the checks
    below, SSA bounds print through the identity map and integer bounds print
    inline.
    """
    memref_t = T.memref(10, T.index())

    # CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>

    # CHECK-LABEL:   func.func @range_loop_1(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_1(lb, ub, memref_v):
        # affine.for_ is called directly; the old ``range = affine.for_`` alias
        # shadowed the builtin for the rest of this function.
        for i in affine.for_(lb, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])

            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_2(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to 10 {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_2(lb, ub, memref_v):
        for i in affine.for_(lb, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_3(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_3(lb, ub, memref_v):
        for i in affine.for_(0, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_4(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_4(lb, ub, memref_v):
        for i in affine.for_(0, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_8(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
    # CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK:             memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK:             affine.yield %[[VAL_5]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_8(lb, ub, memref_v):
        for i, it in affine.for_(0, 10, iter_args=[memref_v]):
            add = arith.addi(i, i)
            memref.store(add, it, [i])
            affine.yield_([it])


# CHECK-LABEL: TEST: testAffineIfWithoutElse
@constructAndPrintInModule
def testAffineIfWithoutElse():
    """Build a result-less ``affine.if`` that has only a then-region.

    The condition is the integer set ``(d0) : (d0 - 5 >= 0)``, applied to the
    function's single index argument.
    """
    index_ty = IndexType.get()
    i32_ty = IntegerType.get_signless(32)
    dim0 = AffineDimExpr.get(0)

    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
    at_least_five = IntegerSet.get(1, 0, [dim0 - 5], [False])

    # CHECK-LABEL:  func.func @simple_affine_if(
    # CHECK-SAME:                        %[[VAL_0:.*]]: index) {
    # CHECK:            affine.if #[[$SET0]](%[[VAL_0]]) {
    # CHECK:                %[[VAL_1:.*]] = arith.constant 1 : i32
    # CHECK:                %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
    # CHECK:            }
    # CHECK:            return
    # CHECK:        }
    @func.FuncOp.from_py_func(index_ty)
    def simple_affine_if(cond_operands):
        if_op = affine.AffineIfOp(at_least_five, cond_operands=[cond_operands])
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32_ty, 1)
            arith.AddIOp(one, one)
            # Per the checks above, this empty terminator is not printed.
            affine.AffineYieldOp([])


# CHECK-LABEL: TEST: testAffineIfWithElse
@constructAndPrintInModule
def testAffineIfWithElse():
    """Build a two-result ``affine.if`` with then- and else-regions.

    The then-branch yields the i32 constants (0, 1) and the else-branch yields
    (2, 3). The two results are then added together.
    """
    index_ty = IndexType.get()
    i32_ty = IntegerType.get_signless(32)
    dim0 = AffineDimExpr.get(0)

    # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
    at_least_five = IntegerSet.get(1, 0, [dim0 - 5], [False])

    # CHECK-LABEL:  func.func @simple_affine_if_else(
    # CHECK-SAME:                                    %[[VAL_0:.*]]: index) {
    # CHECK:            %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
    # CHECK:                %[[VAL_XT:.*]] = arith.constant 0 : i32
    # CHECK:                %[[VAL_YT:.*]] = arith.constant 1 : i32
    # CHECK:                affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
    # CHECK:            } else {
    # CHECK:                %[[VAL_XF:.*]] = arith.constant 2 : i32
    # CHECK:                %[[VAL_YF:.*]] = arith.constant 3 : i32
    # CHECK:                affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
    # CHECK:            }
    # CHECK:            %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
    # CHECK:            return
    # CHECK:        }

    @func.FuncOp.from_py_func(index_ty)
    def simple_affine_if_else(cond_operands):
        if_op = affine.AffineIfOp(
            at_least_five, [i32_ty, i32_ty], cond_operands=[cond_operands], has_else=True
        )

        def yield_pair(block, first, second):
            # Materialize two i32 constants in ``block`` and yield them in order.
            with InsertionPoint(block):
                lhs = arith.ConstantOp(i32_ty, first)
                rhs = arith.ConstantOp(i32_ty, second)
                affine.AffineYieldOp([lhs, rhs])

        yield_pair(if_op.then_block, 0, 1)
        yield_pair(if_op.else_block, 2, 3)
        arith.AddIOp(if_op.results[0], if_op.results[1])