// xref: /llvm-project/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir (revision f8b27949a8c4fa8d8e15f9858e2ed38d7267f7dd)
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds{reify-to-func-args}))' \
// RUN:     -verify-diagnostics -split-input-file | FileCheck %s

// Induction variable bounds of scf.for: the lower bound of %iv reifies to the
// loop's lower bound operand %a and its exclusive upper bound to %b.
// CHECK-LABEL: func @scf_for(
//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   "test.some_use"(%[[a]], %[[b]])
func.func @scf_for(%a: index, %b: index, %c: index) {
  scf.for %iv = %a to %b step %c {
    %0 = "test.reify_bound"(%iv) {type = "LB"} : (index) -> (index)
    %1 = "test.reify_bound"(%iv) {type = "UB"} : (index) -> (index)
    "test.some_use"(%0, %1) : (index, index) -> ()
  }
  return
}

// -----

// An index iter_arg that is yielded unchanged is equal to its init value %i,
// both inside the loop body and for the loop result.
// CHECK-LABEL: func @scf_for_index_result_small(
//  CHECK-SAME:     %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   "test.some_use"(%[[i]])
//       CHECK:   "test.some_use"(%[[i]])
func.func @scf_for_index_result_small(%i: index, %a: index, %b: index, %c: index) {
  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
    %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
    "test.some_use"(%1) : (index) -> ()
    scf.yield %arg : index
  }
  %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
  "test.some_use"(%2) : (index) -> ()
  return
}

// -----

// The yielded value (%arg + %a) - %a is equal to %arg, so both the iter_arg
// and the loop result are still equal to the init value %i.
// CHECK-LABEL: func @scf_for_index_result(
//  CHECK-SAME:     %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   "test.some_use"(%[[i]])
//       CHECK:   "test.some_use"(%[[i]])
func.func @scf_for_index_result(%i: index, %a: index, %b: index, %c: index) {
  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
    %add = arith.addi %arg, %a : index
    %sub = arith.subi %add, %a : index

    %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
    "test.some_use"(%1) : (index) -> ()
    scf.yield %sub : index
  }
  %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
  "test.some_use"(%2) : (index) -> ()
  return
}

// -----

// A tensor iter_arg that is yielded unchanged keeps the dim-0 size of its init
// tensor %t, both inside the loop body and for the loop result.
// CHECK-LABEL: func @scf_for_tensor_result_small(
//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
//       CHECK:   "test.some_use"(%[[dim]])
//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
//       CHECK:   "test.some_use"(%[[dim]])
func.func @scf_for_tensor_result_small(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
    "test.some_use"(%1) : (index) -> ()
    scf.yield %arg : tensor<?xf32>
  }
  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
  "test.some_use"(%2) : (index) -> ()
  return
}

// -----

// The loop yields a linalg.fill whose `outs` operand is the iter_arg itself, so
// the iter_arg and the loop result keep the dim-0 size of the init tensor %t.
// CHECK-LABEL: func @scf_for_tensor_result(
//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
//       CHECK:   "test.some_use"(%[[dim]])
//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]]
//       CHECK:   "test.some_use"(%[[dim]])
func.func @scf_for_tensor_result(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
  %cst = arith.constant 5.0 : f32
  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
    %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
    "test.some_use"(%1) : (index) -> ()
    scf.yield %filled : tensor<?xf32>
  }
  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
  "test.some_use"(%2) : (index) -> ()
  return
}

// -----

// The two iter_args are swapped on every iteration: the first result yields
// the fill of %arg2 and the second yields the fill of %arg1. The dim-0 size of
// %r1 is therefore not provably equal to that of %t1, and no EQ bound can be
// reified for it.
func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: index, %b: index, %c: index) {
  %cst = arith.constant 5.0 : f32
  %r1, %r2 = scf.for %iv = %a to %b step %c iter_args(%arg1 = %t1, %arg2 = %t2) -> (tensor<?xf32>, tensor<?xf32>) {
    %filled1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
    %filled2 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    scf.yield %filled2, %filled1 : tensor<?xf32>, tensor<?xf32>
  }
  // expected-error @below{{could not reify bound}}
  %reify1 = "test.reify_bound"(%r1) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
  "test.some_use"(%reify1) : (index) -> ()
  return
}

// -----

// Induction variable bounds of scf.forall: the lower bound of %iv reifies to
// the loop's lower bound %a and its exclusive upper bound to %b.
// CHECK-LABEL: func @scf_forall(
//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   "test.some_use"(%[[a]], %[[b]])
func.func @scf_forall(%a: index, %b: index, %c: index) {
  scf.forall (%iv) = (%a) to (%b) step (%c) {
    %0 = "test.reify_bound"(%iv) {type = "LB"} : (index) -> (index)
    %1 = "test.reify_bound"(%iv) {type = "UB"} : (index) -> (index)
    "test.some_use"(%0, %1) : (index, index) -> ()
  }
  return
}

// -----

// The shared_outs tensor and the scf.forall result both keep the dim-0 size of
// their init value tensor.empty(%size), so both reify to %size.
// CHECK-LABEL: func @scf_forall_tensor_result(
//  CHECK-SAME:     %[[size:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
//       CHECK:   "test.some_use"(%[[size]])
//       CHECK:   "test.some_use"(%[[size]])
func.func @scf_forall_tensor_result(%size: index, %a: index, %b: index, %c: index) {
  %cst = arith.constant 5.0 : f32
  %empty = tensor.empty(%size) : tensor<?xf32>
  %0 = scf.forall (%iv) = (%a) to (%b) step (%c) shared_outs(%arg = %empty) -> tensor<?xf32> {
    %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
    "test.some_use"(%1) : (index) -> ()
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %filled into %arg[0][%size][1] : tensor<?xf32> into tensor<?xf32>
    }
  }
  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
  "test.some_use"(%2) : (index) -> ()
  return
}

// -----

// The scf.if result is either 4 or 9, so its lower bound is 4 and its
// exclusive upper bound is 10.
// CHECK-LABEL: func @scf_if_constant(
func.func @scf_if_constant(%c : i1) {
  // CHECK: arith.constant 4 : index
  // CHECK: arith.constant 9 : index
  %c4 = arith.constant 4 : index
  %c9 = arith.constant 9 : index
  %r = scf.if %c -> index {
    scf.yield %c4 : index
  } else {
    scf.yield %c9 : index
  }

  // CHECK: %[[c4:.*]] = arith.constant 4 : index
  // CHECK: %[[c10:.*]] = arith.constant 10 : index
  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
  return
}

// -----

// The branches yield %a + %b and %a + %b + 4. The result therefore has lower
// bound %a + %b and exclusive upper bound %a + %b + 5; both are reified as
// affine.apply ops over the function arguments.
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
// CHECK-LABEL: func @scf_if_dynamic(
//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
  %c4 = arith.constant 4 : index
  %r = scf.if %c -> index {
    %add1 = arith.addi %a, %b : index
    scf.yield %add1 : index
  } else {
    %add2 = arith.addi %b, %c4 : index
    %add3 = arith.addi %add2, %a : index
    scf.yield %add3 : index
  }

  // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
  // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
  // CHECK: "test.some_use"(%[[lb]], %[[ub]])
  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
  return
}

// -----

// The branches yield two unrelated function arguments, %a and %b.
func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
  %r = scf.if %c -> index {
    scf.yield %a : index
  } else {
    scf.yield %b : index
  }
  // The reified bound would be min(%a, %b). min/max expressions are not
  // supported in reified bounds.
  // expected-error @below{{could not reify bound}}
  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
  "test.some_use"(%reify1) : (index) -> ()
  return
}

// -----

// The branches yield tensors of size 4 and 9. The dim-0 size of the scf.if
// result therefore has lower bound 4 and exclusive upper bound 10.
// CHECK-LABEL: func @scf_if_tensor_dim(
func.func @scf_if_tensor_dim(%c : i1) {
  // CHECK: arith.constant 4 : index
  // CHECK: arith.constant 9 : index
  %c4 = arith.constant 4 : index
  %c9 = arith.constant 9 : index
  %t1 = tensor.empty(%c4) : tensor<?xf32>
  %t2 = tensor.empty(%c9) : tensor<?xf32>
  %r = scf.if %c -> tensor<?xf32> {
    scf.yield %t1 : tensor<?xf32>
  } else {
    scf.yield %t2 : tensor<?xf32>
  }

  // CHECK: %[[c4:.*]] = arith.constant 4 : index
  // CHECK: %[[c10:.*]] = arith.constant 10 : index
  %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
      : (tensor<?xf32>) -> (index)
  %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
      : (tensor<?xf32>) -> (index)
  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
  return
}

// -----

// Both branches yield %a + %b (the else branch additionally adds a constant 0),
// so an exact (EQ) bound of %a + %b can be reified for the result.
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: func @scf_if_eq(
//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
  %c0 = arith.constant 0 : index
  %r = scf.if %c -> index {
    %add1 = arith.addi %a, %b : index
    scf.yield %add1 : index
  } else {
    %add2 = arith.addi %b, %c0 : index
    %add3 = arith.addi %add2, %a : index
    scf.yield %add3 : index
  }

  // CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
  %reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index)
  // CHECK: "test.some_use"(%[[eq]])
  "test.some_use"(%reify1) : (index) -> ()
  return
}

// -----

// Inside the loop, the induction variable satisfies %a <= %iv < %b.
func.func @compare_scf_for(%a: index, %b: index, %c: index) {
  scf.for %iv = %a to %b step %c {
    // expected-remark @below{{true}}
    "test.compare"(%iv, %a) {cmp = "GE"} : (index, index) -> ()
    // expected-remark @below{{true}}
    "test.compare"(%iv, %b) {cmp = "LT"} : (index, index) -> ()
  }
  return
}

// -----

// The iter_arg starts at 0 and grows by at most 1 in each of the 10 iterations
// (0 to 10, step 1), so the loop result is at most 10.
func.func @scf_for_result_infer() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %c0) -> index {
    %2 = "test.some_use"() : () -> (i1)
    %3 = scf.if %2 -> (index) {
        %5 = arith.addi %arg, %c1 : index
        scf.yield %5 : index
    } else {
        scf.yield %arg : index
    }
    scf.yield %3 : index
  }
  // expected-remark @below{{true}}
  "test.compare"(%0, %c10) {cmp = "LE"} : (index, index) -> ()
  return
}

// -----

// The iter_arg starts at the dynamic value %i and grows by at most 1 in each of
// the 10 iterations, so the loop result is at most %i + 10.
func.func @scf_for_result_infer_dynamic_init(%i : index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %i) -> index {
    %2 = "test.some_use"() : () -> (i1)
    %3 = scf.if %2 -> (index) {
        %5 = arith.addi %arg, %c1 : index
        scf.yield %5 : index
    } else {
        scf.yield %arg : index
    }
    scf.yield %3 : index
  }
  %6 = arith.addi %i, %c10 : index
  // expected-remark @below{{true}}
  "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
  return
}

// -----

// With step 2 the loop 0..10 runs 5 iterations, each adding at most 1 to the
// iter_arg. The result is therefore at most %i + 5. The tighter claim
// "at most %i + 4" cannot be proven, so that comparison is reported as unknown.
func.func @scf_for_result_infer_dynamic_init_big_step(%i : index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c10 = arith.constant 10 : index
  %0 = scf.for %iv = %c0 to %c10 step %c2 iter_args(%arg = %i) -> index {
    %2 = "test.some_use"() : () -> (i1)
    %3 = scf.if %2 -> (index) {
        %5 = arith.addi %arg, %c1 : index
        scf.yield %5 : index
    } else {
        scf.yield %arg : index
    }
    scf.yield %3 : index
  }
  %6 = arith.addi %i, %c5 : index
  %7 = arith.addi %i, %c4 : index
  // expected-remark @below{{true}}
  "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
  // expected-error @below{{unknown}}
  "test.compare"(%0, %7) {cmp = "LE"} : (index, index) -> ()
  return
}