xref: /llvm-project/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir (revision 644cd0724d8826c08c151b2e4ada00e694eb6cbf)
1// RUN: mlir-opt %s -scf-for-loop-canonicalization -split-input-file | FileCheck %s
2
3// CHECK-LABEL: func @scf_for_canonicalize_min
4//       CHECK:   %[[C2:.*]] = arith.constant 2 : i64
5//       CHECK:   scf.for
6//       CHECK:     memref.store %[[C2]], %{{.*}}[] : memref<i64>
// Test input: %i takes the values 0 and 2, so 4 - %i is 4 or 2 and
// min(2, 4 - %i) is always 2. The FileCheck expectations above require the
// affine.min to be replaced by a constant i64 2 that is stored in the loop.
7func.func @scf_for_canonicalize_min(%A : memref<i64>) {
8  %c0 = arith.constant 0 : index
9  %c2 = arith.constant 2 : index
10  %c4 = arith.constant 4 : index
11
12  scf.for %i = %c0 to %c4 step %c2 {
    // d0 is the induction variable, d1 is the upper bound (4).
13    %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
14    %2 = arith.index_cast %1: index to i64
15    memref.store %2, %A[]: memref<i64>
16  }
17  return
18}
19
20// -----
21
22// CHECK-LABEL: func @scf_for_canonicalize_max
23//       CHECK:   %[[Cneg2:.*]] = arith.constant -2 : i64
24//       CHECK:   scf.for
25//       CHECK:     memref.store %[[Cneg2]], %{{.*}}[] : memref<i64>
// Test input: mirror image of the min case. %i takes the values 0 and 2, so
// -(4 - %i) is -4 or -2 and max(-2, -(4 - %i)) is always -2. The FileCheck
// expectations above require a constant i64 -2 to be stored in the loop.
26func.func @scf_for_canonicalize_max(%A : memref<i64>) {
27  %c0 = arith.constant 0 : index
28  %c2 = arith.constant 2 : index
29  %c4 = arith.constant 4 : index
30
31  scf.for %i = %c0 to %c4 step %c2 {
    // d0 is the induction variable, d1 is the upper bound (4).
32    %1 = affine.max affine_map<(d0, d1)[] -> (-2, -(d1 - d0))> (%i, %c4)
33    %2 = arith.index_cast %1: index to i64
34    memref.store %2, %A[]: memref<i64>
35  }
36  return
37}
38
39// -----
40
41// CHECK-LABEL: func @scf_for_max_not_canonicalizable
42//       CHECK:   scf.for
43//       CHECK:     affine.max
44//       CHECK:     arith.index_cast
// Negative test: the affine.max uses the constant 3 rather than the loop's
// upper bound 4. With %i in {0, 2}, -(3 - %i) is -3 or -1, so the max is -2 or
// -1 and is not a single constant. The FileCheck expectations above require
// the affine.max to remain inside the loop.
45func.func @scf_for_max_not_canonicalizable(%A : memref<i64>) {
46  %c0 = arith.constant 0 : index
47  %c2 = arith.constant 2 : index
48  %c3 = arith.constant 3 : index
49  %c4 = arith.constant 4 : index
50
51  scf.for %i = %c0 to %c4 step %c2 {
    // Note the second operand is %c3, not the loop's upper bound %c4.
52    %1 = affine.max affine_map<(d0, d1)[] -> (-2, -(d1 - d0))> (%i, %c3)
53    %2 = arith.index_cast %1: index to i64
54    memref.store %2, %A[]: memref<i64>
55  }
56  return
57}
58
59// -----
60
61// CHECK-LABEL: func @scf_for_loop_nest_canonicalize_min
62//       CHECK:   %[[C5:.*]] = arith.constant 5 : i64
63//       CHECK:   scf.for
64//       CHECK:     scf.for
65//       CHECK:       memref.store %[[C5]], %{{.*}}[] : memref<i64>
// Test input: an affine.min that depends on the induction variables of two
// nested loops. With %i in {0, 2} and %j in {0, 3}, 4 + 6 - %i - %j is one of
// 10, 8, 7 or 5, so min(5, ...) is always 5. The FileCheck expectations above
// require a constant i64 5 to be stored in the inner loop.
66func.func @scf_for_loop_nest_canonicalize_min(%A : memref<i64>) {
67  %c0 = arith.constant 0 : index
68  %c2 = arith.constant 2 : index
69  %c3 = arith.constant 3 : index
70  %c4 = arith.constant 4 : index
71  %c6 = arith.constant 6 : index
72
73  scf.for %i = %c0 to %c4 step %c2 {
74    scf.for %j = %c0 to %c6 step %c3 {
      // d0/d1 are the outer induction variable and bound, d2/d3 the inner ones.
75      %1 = affine.min affine_map<(d0, d1, d2, d3)[] -> (5, d1 + d3 - d0 - d2)> (%i, %c4, %j, %c6)
76      %2 = arith.index_cast %1: index to i64
77      memref.store %2, %A[]: memref<i64>
78    }
79  }
80  return
81}
82
83// -----
84
85// CHECK-LABEL: func @scf_for_not_canonicalizable_1
86//       CHECK:   scf.for
87//       CHECK:     affine.min
88//       CHECK:     arith.index_cast
// Negative test: the loop starts at 1, so %i takes the values 1 and 3 and
// min(2, 4 - %i) is 2 or 1. The FileCheck expectations above require the
// affine.min to remain inside the loop.
89func.func @scf_for_not_canonicalizable_1(%A : memref<i64>) {
90  // This should not canonicalize because: 4 - %i may take the value 1 < 2.
91  %c1 = arith.constant 1 : index
92  %c2 = arith.constant 2 : index
93  %c4 = arith.constant 4 : index
94
95  scf.for %i = %c1 to %c4 step %c2 {
    // Here the upper bound is passed as a symbol (s0) rather than a dim.
96    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c4]
97    %2 = arith.index_cast %1: index to i64
98    memref.store %2, %A[]: memref<i64>
99  }
100  return
101}
102
103// -----
104
105// CHECK-LABEL: func @scf_for_canonicalize_partly
106//       CHECK:   scf.for
107//       CHECK:     affine.apply
108//       CHECK:     arith.index_cast
// Test input: %i is at least 1, so 256 - %i is always below 256 and
// min(256, 256 - %i) equals 256 - %i. The result still depends on %i; the
// FileCheck expectations above require an affine.apply (instead of an
// affine.min) inside the loop.
109func.func @scf_for_canonicalize_partly(%A : memref<i64>) {
110  // This should canonicalize only partly: 256 - %i <= 256.
111  %c1 = arith.constant 1 : index
112  %c16 = arith.constant 16 : index
113  %c256 = arith.constant 256 : index
114
115  scf.for %i = %c1 to %c256 step %c16 {
116    %1 = affine.min affine_map<(d0) -> (256, 256 - d0)> (%i)
117    %2 = arith.index_cast %1: index to i64
118    memref.store %2, %A[]: memref<i64>
119  }
120  return
121}
122
123// -----
124
125// CHECK-LABEL: func @scf_for_not_canonicalizable_2
126//       CHECK: scf.for
127//       CHECK:   affine.min
128//       CHECK:   arith.index_cast
// Negative test with a dynamic step: the upper bound is 42 * %step and the
// loop computes min(%step, %ub - %i). The FileCheck expectations above require
// the affine.min to remain inside the loop.
129func.func @scf_for_not_canonicalizable_2(%A : memref<i64>, %step : index) {
130  // This example should simplify but affine_map is currently missing
131  // semi-affine canonicalizations: `((s0 * 42 - 1) floordiv s0) * s0`
132  // should evaluate to 41 * s0.
133  // Note that this may require positivity assumptions on `s0`.
134  // Revisit when support is added.
135  %c0 = arith.constant 0 : index
136
  // Upper bound is a multiple (42x) of the dynamic step.
137  %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
138  scf.for %i = %c0 to %ub step %step {
139    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)> (%step, %ub, %i)
140    %2 = arith.index_cast %1: index to i64
141    memref.store %2, %A[]: memref<i64>
142  }
143  return
144}
145
146// -----
147
148// CHECK-LABEL: func @scf_for_not_canonicalizable_3
149//       CHECK: scf.for
150//       CHECK:   affine.min
151//       CHECK:   arith.index_cast
// Negative test with a dynamic step: the upper bound is %step * %step and the
// loop computes min(%step, %ub2 - %i). The FileCheck expectations above
// require the affine.min to remain inside the loop.
152func.func @scf_for_not_canonicalizable_3(%A : memref<i64>, %step : index) {
153  // This example should simplify but affine_map is currently missing
154  // semi-affine canonicalizations: `-(((s0 * s0 - 1) floordiv s0) * s0)`
155  // should evaluate to (s0 - 1) * s0.
  // NOTE(review): as quoted, the expression has a leading minus, so for
  // s0 > 0 it equals -((s0 - 1) * s0); confirm the intended sign.
156  // Note that this may require positivity assumptions on `s0`.
157  // Revisit when support is added.
158  %c0 = arith.constant 0 : index
159
  // Upper bound is the square of the dynamic step (dim d0 times symbol s0).
160  %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
161  scf.for %i = %c0 to %ub2 step %step {
    // Operand order is (%step, %i, %ub2), so d2 - d1 is %ub2 - %i.
162    %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
163    %2 = arith.index_cast %1: index to i64
164    memref.store %2, %A[]: memref<i64>
165  }
166  return
167}
168
169// -----
170
171// CHECK-LABEL: func @scf_for_invalid_loop
172//       CHECK: scf.for
173//       CHECK:   affine.min
174//       CHECK:   arith.index_cast
// Negative test: the loop's lower bound (256) is greater than its upper bound
// (1). The FileCheck expectations above require the affine.min to remain
// inside the loop.
175func.func @scf_for_invalid_loop(%A : memref<i64>, %step : index) {
176  // This is an invalid loop. It should not be touched by the canonicalization
177  // pattern.
178  %c1 = arith.constant 1 : index
179  %c7 = arith.constant 7 : index
180  %c256 = arith.constant 256 : index
181
  // NOTE(review): the %step argument is not used in this function.
182  scf.for %i = %c256 to %c1 step %c1 {
183    %1 = affine.min affine_map<(d0)[s0] -> (s0 + d0, 0)> (%i)[%c7]
184    %2 = arith.index_cast %1: index to i64
185    memref.store %2, %A[]: memref<i64>
186  }
187  return
188}
189
190// -----
191
192// CHECK-LABEL: func @scf_parallel_canonicalize_min_1
193//       CHECK:   %[[C2:.*]] = arith.constant 2 : i64
194//       CHECK:   scf.parallel
195//  CHECK-NEXT:     memref.store %[[C2]], %{{.*}}[] : memref<i64>
// Test input: same bounds and affine.min as the first scf.for test, but using
// scf.parallel. %i takes the values 0 and 2, so min(2, 4 - %i) is always 2.
// The FileCheck expectations above require the store of a constant i64 2 to be
// the first op in the scf.parallel body.
196func.func @scf_parallel_canonicalize_min_1(%A : memref<i64>) {
197  %c0 = arith.constant 0 : index
198  %c2 = arith.constant 2 : index
199  %c4 = arith.constant 4 : index
200
201  scf.parallel (%i) = (%c0) to (%c4) step (%c2) {
202    %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
203    %2 = arith.index_cast %1: index to i64
204    memref.store %2, %A[]: memref<i64>
205  }
206  return
207}
208
209// -----
210
211// CHECK-LABEL: func @scf_parallel_canonicalize_min_2
212//       CHECK:   %[[C2:.*]] = arith.constant 2 : i64
213//       CHECK:   scf.parallel
214//  CHECK-NEXT:     memref.store %[[C2]], %{{.*}}[] : memref<i64>
// Test input: scf.parallel with a non-zero lower bound. %i takes the values
// 1, 3 and 5, so 7 - %i is 6, 4 or 2 and min(2, 7 - %i) is always 2. The
// FileCheck expectations above require the store of a constant i64 2 to be the
// first op in the scf.parallel body.
215func.func @scf_parallel_canonicalize_min_2(%A : memref<i64>) {
216  %c1 = arith.constant 1 : index
217  %c2 = arith.constant 2 : index
218  %c7 = arith.constant 7 : index
219
220  scf.parallel (%i) = (%c1) to (%c7) step (%c2) {
    // Here the upper bound is passed as a symbol (s0) rather than a dim.
221    %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7]
222    %2 = arith.index_cast %1: index to i64
223    memref.store %2, %A[]: memref<i64>
224  }
225  return
226}
227
228// -----
229
230// CHECK-LABEL: func @tensor_dim_of_iter_arg(
231//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
232//       CHECK:   scf.for
233//       CHECK:     tensor.dim %[[t]]
// Test input: tensor.dim is taken on an iter_arg that the loop yields back
// unchanged. The FileCheck expectations above require the tensor.dim inside
// the loop to operate on the function argument %t instead of the iter_arg.
234func.func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
235  %c0 = arith.constant 0 : index
236  %c1 = arith.constant 1 : index
237  %c10 = arith.constant 10 : index
238  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
239      -> (tensor<?x?xf32>, index) {
240    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    // %arg0 is forwarded as-is to the next iteration.
241    scf.yield %arg0, %dim : tensor<?x?xf32>, index
242  }
243  return %1 : index
244}
245
246// -----
247
248// CHECK-LABEL: func @tensor_dim_of_iter_arg_insertslice(
249//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
250//       CHECK:   scf.for
251//       CHECK:     tensor.dim %[[t]]
// Test input: the iter_arg is not yielded directly; the yielded tensor is
// produced by a chain of two tensor.insert_slice ops whose destination chain
// starts at the iter_arg. The FileCheck expectations above still require the
// tensor.dim inside the loop to operate on the function argument %t.
252func.func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
253                                         %t2 : tensor<10x10xf32>) -> index {
254  %c0 = arith.constant 0 : index
255  %c1 = arith.constant 1 : index
256  %c10 = arith.constant 10 : index
257  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
258      -> (tensor<?x?xf32>, index) {
259    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    // First insertion writes into the iter_arg itself ...
260    %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
261        : tensor<10x10xf32> into tensor<?x?xf32>
    // ... the second writes into the result of the first.
262    %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
263        : tensor<10x10xf32> into tensor<?x?xf32>
264    scf.yield %3, %dim : tensor<?x?xf32>, index
265  }
266  return %1 : index
267}
268
269// -----
270
271// CHECK-LABEL: func @tensor_dim_of_iter_arg_nested_for(
272//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
273//       CHECK:   scf.for
274//       CHECK:     scf.for
275//       CHECK:       tensor.dim %[[t]]
// Test input: two nested loops. The inner loop's tensor iter_arg is
// initialized from the outer loop's iter_arg, which in turn is initialized
// from %t. The FileCheck expectations above require the tensor.dim in the
// inner loop to operate on the function argument %t.
276func.func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
277                                        %t2 : tensor<10x10xf32>) -> index {
278  %c0 = arith.constant 0 : index
279  %c1 = arith.constant 1 : index
280  %c10 = arith.constant 10 : index
281  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
282      -> (tensor<?x?xf32>, index) {
283    %2, %3 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1)
284        -> (tensor<?x?xf32>, index) {
285      %dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>
286      %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
287          : tensor<10x10xf32> into tensor<?x?xf32>
288      scf.yield %4, %dim : tensor<?x?xf32>, index
289    }
    // The outer loop forwards the inner loop's results unchanged.
290    scf.yield %2, %3 : tensor<?x?xf32>, index
291  }
292  return %1 : index
293}
294
295
296// -----
297
298// A test case that should not canonicalize because the loop is not shape
299// conserving.
300
301// CHECK-LABEL: func @tensor_dim_of_iter_arg_no_canonicalize(
302//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
303//       CHECK:   scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[t]]
304//       CHECK:     tensor.dim %[[arg0]]
// Negative test: the loop yields %t2, a different dynamically shaped tensor,
// in place of the iter_arg. The FileCheck expectations above require the
// tensor.dim to keep operating on the iter_arg.
305func.func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
306                                             %t2 : tensor<?x?xf32>) -> index {
307  %c0 = arith.constant 0 : index
308  %c1 = arith.constant 1 : index
309  %c10 = arith.constant 10 : index
310  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
311      -> (tensor<?x?xf32>, index) {
312    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    // %t2 (not %arg0) becomes the iter_arg of the next iteration.
313    scf.yield %t2, %dim : tensor<?x?xf32>, index
314  }
315  return %1 : index
316}
317
318// -----
319
320// CHECK-LABEL: func @tensor_dim_of_loop_result(
321//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
322//       CHECK:   tensor.dim %[[t]]
// Test input: tensor.dim is taken on the result of a loop whose body only
// yields its iter_arg back. The FileCheck expectations above require the
// tensor.dim to operate on the function argument %t instead of the loop
// result.
323func.func @tensor_dim_of_loop_result(%t : tensor<?x?xf32>) -> index {
324  %c0 = arith.constant 0 : index
325  %c1 = arith.constant 1 : index
326  %c10 = arith.constant 10 : index
327  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t)
328      -> (tensor<?x?xf32>) {
329    scf.yield %arg0 : tensor<?x?xf32>
330  }
  // The dim is queried outside the loop, on the loop result.
331  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
332  return %dim : index
333}
334
335// -----
336
337// CHECK-LABEL: func @tensor_dim_of_loop_result_no_canonicalize(
338//       CHECK:   %[[loop:.*]]:2 = scf.for
339//       CHECK:   tensor.dim %[[loop]]#1
// Negative test: tensor.dim is taken on the second loop result. The loop
// yields the function argument %u for that position rather than the
// corresponding iter_arg %arg1. The FileCheck expectations above require the
// tensor.dim to keep operating on the loop's result #1.
340func.func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
341                                                %u : tensor<?x?xf32>) -> index {
342  %c0 = arith.constant 0 : index
343  %c1 = arith.constant 1 : index
344  %c10 = arith.constant 10 : index
345  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %u)
346      -> (tensor<?x?xf32>, tensor<?x?xf32>) {
    // Second yielded value is %u, not the iter_arg %arg1.
347    scf.yield %arg0, %u : tensor<?x?xf32>, tensor<?x?xf32>
348  }
349  %dim = tensor.dim %1, %c0 : tensor<?x?xf32>
350  return %dim : index
351}
352
353// -----
354
355// CHECK-LABEL: func @one_trip_scf_for_canonicalize_min
356//       CHECK:   %[[C4:.*]] = arith.constant 4 : i64
357//       CHECK:   scf.for
358//       CHECK:     memref.store %[[C4]], %{{.*}}[] : memref<i64>
// Test input: the step equals the upper bound (4) and the lower bound is 0, so
// the only induction value is 0 and min(4, 4 - %i) is 4. The FileCheck
// expectations above require a constant i64 4 to be stored in the loop.
359func.func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
360  %c0 = arith.constant 0 : index
  // NOTE(review): %c2 is not used anywhere in this function.
361  %c2 = arith.constant 2 : index
362  %c4 = arith.constant 4 : index
363
364  scf.for %i = %c0 to %c4 step %c4 {
365    %1 = affine.min affine_map<(d0, d1)[] -> (4, d1 - d0)> (%i, %c4)
366    %2 = arith.index_cast %1: index to i64
367    memref.store %2, %A[]: memref<i64>
368  }
369  return
370}
371
372// -----
373
374// This is a regression test to ensure that no assertions are failing.
375
376//       CHECK: #[[$map:.+]] = affine_map<(d0)[s0] -> (-(d0 * (5 ceildiv s0)) + 5, 3)>
377// CHECK-LABEL: func @regression_multiplication_with_sym
// Test input: the loop's upper bound comes from an opaque "test.dummy" op and
// is used as a symbol inside a product (d0 * (5 ceildiv s0)) in the affine.min
// map. The FileCheck expectations require the affine.min to be printed with
// the same map, applied to the induction variable and the dummy value.
378func.func @regression_multiplication_with_sym(%A : memref<i64>) {
379  %c0 = arith.constant 0 : index
380  %c1 = arith.constant 1 : index
  // NOTE(review): %c2 and %c4 are not used anywhere in this function.
381  %c2 = arith.constant 2 : index
382  %c4 = arith.constant 4 : index
383  // CHECK: %[[dummy:.*]] = "test.dummy"
384  %ub = "test.dummy"() : () -> (index)
385  // CHECK: scf.for %[[iv:.*]] =
386  scf.for %i = %c0 to %ub step %c1 {
387    // CHECK: affine.min #[[$map]](%[[iv]])[%[[dummy]]]
388    %1 = affine.min affine_map<(d0)[s0] -> (-(d0 * (5 ceildiv s0)) + 5, 3)>(%i)[%ub]
389    %2 = arith.index_cast %1: index to i64
390    memref.store %2, %A[]: memref<i64>
391  }
392  return
393}
394
395// -----
396
397
398// Make sure min is transformed into zero.
399
400// CHECK-LABEL: func.func @func1()
401//       CHECK:   %[[ZERO:.+]] = arith.constant 0 : index
402//       CHECK:   call @foo(%[[ZERO]]) : (index) -> ()
403
// Test input: an affine.min outside of any loop, whose operands are all
// constants or affine.apply results of constants. The FileCheck expectations
// above require @foo to be called with a constant index 0.
404#map6 = affine_map<(d0, d1, d2) -> (d0 floordiv 64)>
405#map29 = affine_map<(d0, d1, d2) -> (d2 * 64 - 2, 5, (d1 mod 4) floordiv 8)>
406module {
  // External callee that consumes the computed index.
407  func.func private @foo(%0 : index) -> ()
408
409  func.func @func1() {
    // NOTE(review): %true and %alloc_249 are not used anywhere in this function.
410    %true = arith.constant true
411    %c0 = arith.constant 0 : index
412    %c5 = arith.constant 5 : index
413    %c11 = arith.constant 11 : index
414    %c14 = arith.constant 14 : index
415    %c15 = arith.constant 15 : index
416    %alloc_249 = memref.alloc() : memref<7xf32>
    // #map6 only uses d0: 15 floordiv 64 = 0.
417    %135 = affine.apply #map6(%c15, %c0, %c14)
    // With d1 = 0 and d2 = 11 the candidates are 11 * 64 - 2 = 702, 5 and
    // (0 mod 4) floordiv 8 = 0, so the minimum is 0.
418    %163 = affine.min #map29(%c5, %135, %c11)
419    func.call @foo(%163) : (index) -> ()
420    return
421  }
422}
423