xref: /llvm-project/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (revision 6050cf28846e5be2c162108f1a024d5ff25d5637)
1// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(scf-parallel-loop-fusion))' -split-input-file | FileCheck %s
2
// Two empty loops with an identical 2x2 iteration space collapse into a
// single loop.
func.func @fuse_empty_loops() {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @fuse_empty_loops
// CHECK-DAG:    %[[TWO:.*]] = arith.constant 2 : index
// CHECK-DAG:    %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:    %[[ONE:.*]] = arith.constant 1 : index
// CHECK:        scf.parallel (%{{.*}}, %{{.*}}) = (%[[ZERO]], %[[ZERO]])
// CHECK-SAME:       to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]]) {
// CHECK:          scf.reduce
// CHECK:        }
// CHECK-NOT:    scf.parallel
24
25// -----
26
// An op that uses neither loop (an arith.addf on the function arguments) may
// sit between the loops without blocking fusion; the merged loop is emitted
// after that op.
func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  %total = arith.addf %A, %B : f32
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  return %total : f32
}
// CHECK-LABEL: func @fuse_ops_between
// CHECK-DAG:    %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:    %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG:    %[[TWO:.*]] = arith.constant 2 : index
// CHECK:        {{%.*}} = arith.addf {{%.*}}, {{%.*}} : f32
// CHECK:        scf.parallel (%{{.*}}, %{{.*}}) = (%[[ZERO]], %[[ZERO]])
// CHECK-SAME:       to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]]) {
// CHECK:          scf.reduce
// CHECK:        }
// CHECK-NOT:    scf.parallel
50
51// -----
52
// The consumer loop loads %sum at exactly the [%x, %y] element the producer
// loop stored, so both bodies end up in one loop, producer ops first.
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %one_f = arith.constant 1.0 : f32
  %sum = memref.alloc()  : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %b_plus_one = arith.addf %b_val, %one_f : f32
    memref.store %b_plus_one, %sum[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %s_val = memref.load %sum[%x, %y] : memref<2x2xf32>
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    %prod = arith.mulf %s_val, %a_val : f32
    memref.store %prod, %B[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_two
// CHECK-SAME:   (%[[IN_A:.*]]: {{.*}}, %[[IN_B:.*]]: {{.*}}) {
// CHECK-DAG:  %[[TWO:.*]] = arith.constant 2 : index
// CHECK-DAG:  %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:  %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG:  %[[ONE_F:.*]] = arith.constant 1.
// CHECK:      %[[BUF:.*]] = memref.alloc()
// CHECK:      scf.parallel (%[[X:.*]], %[[Y:.*]]) = (%[[ZERO]], %[[ZERO]])
// CHECK-SAME:     to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]]) {
// CHECK:        %[[BV:.*]] = memref.load %[[IN_B]][%[[X]], %[[Y]]]
// CHECK:        %[[BP1:.*]] = arith.addf %[[BV]], %[[ONE_F]]
// CHECK:        memref.store %[[BP1]], %[[BUF]][%[[X]], %[[Y]]]
// CHECK-NOT:    scf.parallel
// CHECK:        %[[SV:.*]] = memref.load %[[BUF]][%[[X]], %[[Y]]]
// CHECK:        %[[AV:.*]] = memref.load %[[IN_A]][%[[X]], %[[Y]]]
// CHECK:        %[[PROD:.*]] = arith.mulf %[[SV]], %[[AV]]
// CHECK:        memref.store %[[PROD]], %[[IN_B]][%[[X]], %[[Y]]]
// CHECK:        scf.reduce
// CHECK:      }
// CHECK:      memref.dealloc %[[BUF]]
95
96// -----
97
// Three loops over the same 2x2 space fuse into a single loop. The second loop
// reads %sum at the indices where the first wrote it. The third reads %A and
// writes %B at [%i, %j], the same indices at which the first loop reads %B.
func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
  %c2fp = arith.constant 2.0 : f32
  %sum = memref.alloc()  : memref<2x2xf32>
  %prod = memref.alloc()  : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %c1fp : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %c2fp : f32
    memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %res_elem = arith.addf %A_elem, %c2fp : f32
    memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
    // Terminator spelled out explicitly, like every other loop in this file.
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  memref.dealloc %prod : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_three
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
// CHECK-DAG:  [[C2FP:%.*]] = arith.constant 2.
// CHECK:      [[SUM:%.*]] = memref.alloc()
// CHECK:      [[PROD:%.*]] = memref.alloc()
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:  scf.parallel
// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
// CHECK:        memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:  scf.parallel
// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:        [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
// CHECK:        memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        scf.reduce
// CHECK:      }
// CHECK:      memref.dealloc [[SUM]]
// CHECK:      memref.dealloc [[PROD]]
153
154// -----
155
// The first loop contains a nested scf.parallel; it must not be merged with
// its flat sibling, so all three loops remain in the output.
func.func @do_not_fuse_nested_ploop1() {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.parallel (%u, %v) = (%zero, %zero) to (%two, %two) step (%one, %one) {
      scf.reduce
    }
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop1
// CHECK:        scf.parallel
// CHECK:          scf.parallel
// CHECK:        scf.parallel
175
176// -----
177
// Mirror of the previous case: here the second loop is the one holding a
// nested scf.parallel, and again no fusion takes place.
func.func @do_not_fuse_nested_ploop2() {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.parallel (%u, %v) = (%zero, %zero) to (%two, %two) step (%one, %one) {
      scf.reduce
    }
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop2
// CHECK:        scf.parallel
// CHECK:        scf.parallel
// CHECK:          scf.parallel
197
198// -----
199
// A 2-D loop and a 1-D loop have different numbers of induction variables,
// so they stay separate.
func.func @do_not_fuse_loops_unmatching_num_loops() {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  scf.parallel (%x) = (%zero) to (%two) step (%one) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
// CHECK:        scf.parallel
// CHECK:        scf.parallel
215
216// -----
217
// A memref.alloc placed between the two loops is a side-effecting op; the
// loops are expected to remain separate.
func.func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  %scratch = memref.alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
// CHECK:        scf.parallel
// CHECK:        scf.parallel
234
235// -----
236
// The loops differ in upper bound (4 vs 2) and step (2 vs 1), so their
// iteration spaces do not line up and no fusion happens.
func.func @do_not_fuse_loops_unmatching_iteration_space() {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %four = arith.constant 4 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%four, %four) step (%two, %two) {
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
// CHECK:        scf.parallel
// CHECK:        scf.parallel
253
254// -----
255
// The first loop writes %tmp[%x, %y] but the second reads %tmp[%x + 1, %y],
// an element produced by a different iteration, so the loops are not fused.
func.func @do_not_fuse_unmatching_write_read_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %tmp = memref.alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %c_val = memref.load %C[%x, %y] : memref<2x2xf32>
    %bc = arith.addf %b_val, %c_val : f32
    memref.store %bc, %tmp[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %next_row = arith.addi %x, %one : index
    %t_val = memref.load %tmp[%next_row, %y] : memref<2x2xf32>
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    %prod = arith.mulf %t_val, %a_val : f32
    memref.store %prod, %result[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %tmp : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
// CHECK:        scf.parallel
// CHECK:        scf.parallel
284
285// -----
286
// The first loop reads %common_buf[%x, %y], while the second writes it
// transposed at [%y, %x] (and reads %tmp at a shifted row); these mismatched
// access patterns keep the loops apart.
func.func @do_not_fuse_unmatching_read_write_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %tmp = memref.alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %shared_val = memref.load %common_buf[%x, %y] : memref<2x2xf32>
    %bs = arith.addf %b_val, %shared_val : f32
    memref.store %bs, %tmp[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %next_row = arith.addi %x, %one : index
    %t_val = memref.load %tmp[%next_row, %y] : memref<2x2xf32>
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    %prod = arith.mulf %t_val, %a_val : f32
    memref.store %prod, %common_buf[%y, %x] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %tmp : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
// CHECK:        scf.parallel
// CHECK:        scf.parallel
314
315// -----
316
// The second loop builds a memref.subview of %storage inside its own body and
// loads through it; the loops are expected to stay separate.
func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %storage  = memref.alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %view = memref.subview %storage[%zero, %zero][%two, %two][%one, %one]
      : memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
    %elem = memref.load %view[%x, %y] : memref<?x?xf32, strided<[?, ?], offset: ?>>
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
// CHECK:        scf.parallel
// CHECK:        scf.parallel
336
337// -----
338
// Fusion also applies to sibling loops nested inside an outer scf.parallel:
// the two inner 2x2 loops merge, while the outer loop is left intact.
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
  %sum = memref.alloc()  : memref<2x2xf32>
  scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
      %sum_elem = arith.addf %B_elem, %c1fp : f32
      memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
      scf.reduce
    }
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
      %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
      %product_elem = arith.mulf %sum_elem, %A_elem : f32
      memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
      scf.reduce
    }
    // Outer-loop terminator spelled out explicitly, like every other loop in
    // this file.
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @nested_fuse
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
// CHECK:      [[SUM:%.*]] = memref.alloc()
// CHECK:      scf.parallel
// CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:          [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:          [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK:          memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:      scf.parallel
// CHECK:          [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:          [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:          [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:          memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:          scf.reduce
// CHECK:        }
// CHECK:      }
// CHECK:      memref.dealloc [[SUM]]
385
386// -----
387
// %sum and %result are function arguments and may alias the other memref
// arguments (%A, %B, %C), so the loops are expected to stay separate even
// though their index patterns match.
func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
                             %C: memref<2x2xf32>, %result: memref<2x2xf32>,
                             %sum: memref<2x2xf32>) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %c_val = memref.load %C[%x, %y] : memref<2x2xf32>
    %bc = arith.addf %b_val, %c_val : f32
    memref.store %bc, %sum[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %s_val = memref.load %sum[%x, %y] : memref<2x2xf32>
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    %prod = arith.mulf %s_val, %a_val : f32
    memref.store %prod, %result[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK:      scf.parallel
// CHECK:      scf.parallel
414
415// -----
416
// The first loop stores to %sum[%i, %j] twice (a zero-initialization followed
// by the real value). Both stores and the second loop's load use the same
// indices, so the loops still fuse.
func.func @fuse_when_1st_has_multiple_stores(
  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c0fp = arith.constant 0.0 : f32
  %sum = memref.alloc()  : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %B_elem : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:  [[C0F32:%.*]] = arith.constant 0.
// CHECK:      [[SUM:%.*]] = memref.alloc()
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// Pin the zero-initializing store too, so dropping or reordering it is caught.
// CHECK:        memref.store [[C0F32]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK-NOT:  scf.parallel
// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        scf.reduce
// CHECK:      }
// CHECK:      memref.dealloc [[SUM]]
461
462// -----
463
// The first loop's second store targets %sum[%c0, %j], while the second loop
// reads %sum[%i, %j]. Because the written and read indices differ, the loops
// must not be fused.
func.func @do_not_fuse_multiple_stores_on_diff_indices(
  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c0fp = arith.constant 0.0 : f32
  %sum = memref.alloc()  : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %B_elem : f32
    memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG:  [[C0F32:%.*]] = arith.constant 0.
// CHECK:      [[SUM:%.*]] = memref.alloc()
// First loop is left untouched, including its zero-initializing store.
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:        memref.store [[C0F32]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
// CHECK:        scf.reduce
// CHECK:      }
// Second loop follows as a separate op with its own induction variables.
// CHECK:      scf.parallel ([[I2:%.*]], [[J2:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I2]], [[J2]]]
// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I2]], [[J2]]]
// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I2]], [[J2]]]
// CHECK:        scf.reduce
// CHECK:      }
// CHECK:      memref.dealloc [[SUM]]
509
510// -----
511
// Both loops index %sum with the same affine.apply of their induction
// variables, so the second loop reads exactly the element the first wrote and
// the loops fuse.
func.func @fuse_same_indices_by_affine_apply(
  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %sum = memref.alloc()  : memref<2x3xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %col = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%x, %y)
    memref.store %b_val, %sum[%x, %col] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %col = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%x, %y)
    %s_val = memref.load %sum[%x, %col] : memref<2x3xf32>
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    %prod = arith.mulf %s_val, %a_val : f32
    memref.store %prod, %B[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x3xf32>
  return
}
// CHECK:      #[[$ADD_MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: fuse_same_indices_by_affine_apply
// CHECK-SAME:  (%[[IN_A:.*]]: memref<2x2xf32>, %[[IN_B:.*]]: memref<2x2xf32>) {
// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG:   %[[TWO:.*]] = arith.constant 2 : index
// CHECK:       %[[BUF:.*]] = memref.alloc() : memref<2x3xf32>
// CHECK-NEXT:  scf.parallel (%[[X:.*]], %[[Y:.*]]) = (%[[ZERO]], %[[ZERO]]) to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]]) {
// CHECK-NEXT:    %[[BV:.*]] = memref.load %[[IN_B]][%[[X]], %[[Y]]] : memref<2x2xf32>
// CHECK-NEXT:    %[[COL1:.*]] = affine.apply #[[$ADD_MAP]](%[[X]], %[[Y]])
// CHECK-NEXT:    memref.store %[[BV]], %[[BUF]][%[[X]], %[[COL1]]] : memref<2x3xf32>
// CHECK-NEXT:    %[[COL2:.*]] = affine.apply #[[$ADD_MAP]](%[[X]], %[[Y]])
// CHECK-NEXT:    %[[SV:.*]] = memref.load %[[BUF]][%[[X]], %[[COL2]]] : memref<2x3xf32>
// CHECK-NEXT:    %[[AV:.*]] = memref.load %[[IN_A]][%[[X]], %[[Y]]] : memref<2x2xf32>
// CHECK-NEXT:    %[[PROD:.*]] = arith.mulf %[[SV]], %[[AV]] : f32
// CHECK-NEXT:    memref.store %[[PROD]], %[[IN_B]][%[[X]], %[[Y]]] : memref<2x2xf32>
// CHECK-NEXT:    scf.reduce
// CHECK-NEXT:  }
// CHECK-NEXT:  memref.dealloc %[[BUF]] : memref<2x3xf32>
// CHECK-NEXT:  return
555
556// -----
557
// Each affine.apply combines an induction variable with a different function
// argument (%OffsetA vs %OffsetB), so the written and read columns of %sum
// are not known to match and the loops remain separate.
func.func @do_not_fuse_affine_apply_to_non_ind_var(
  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %OffsetA: index, %OffsetB: index) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %sum = memref.alloc()  : memref<2x3xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %col = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%x, %OffsetA)
    memref.store %b_val, %sum[%x, %col] : memref<2x3xf32>
    scf.reduce
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %col = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%x, %OffsetB)
    %s_val = memref.load %sum[%x, %col] : memref<2x3xf32>
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    %prod = arith.mulf %s_val, %a_val : f32
    memref.store %prod, %B[%x, %y] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x3xf32>
  return
}
// CHECK:       #[[$ADD_MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: do_not_fuse_affine_apply_to_non_ind_var
// CHECK-SAME:  (%[[IN_A:.*]]: memref<2x2xf32>, %[[IN_B:.*]]: memref<2x2xf32>, %[[OFF_A:.*]]: index, %[[OFF_B:.*]]: index) {
// CHECK-DAG:     %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[TWO:.*]] = arith.constant 2 : index
// CHECK:         %[[BUF:.*]] = memref.alloc() : memref<2x3xf32>
// CHECK-NEXT:    scf.parallel (%[[X1:.*]], %[[Y1:.*]]) = (%[[ZERO]], %[[ZERO]]) to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]]) {
// CHECK-NEXT:      %[[BV:.*]] = memref.load %[[IN_B]][%[[X1]], %[[Y1]]] : memref<2x2xf32>
// CHECK-NEXT:      %[[COL1:.*]] = affine.apply #[[$ADD_MAP]](%[[X1]], %[[OFF_A]])
// CHECK-NEXT:      memref.store %[[BV]], %[[BUF]][%[[X1]], %[[COL1]]] : memref<2x3xf32>
// CHECK-NEXT:      scf.reduce
// CHECK-NEXT:    }
// CHECK-NEXT:    scf.parallel (%[[X2:.*]], %[[Y2:.*]]) = (%[[ZERO]], %[[ZERO]]) to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]]) {
// CHECK-NEXT:      %[[COL2:.*]] = affine.apply #[[$ADD_MAP]](%[[X2]], %[[OFF_B]])
// CHECK-NEXT:      %[[SV:.*]] = memref.load %[[BUF]][%[[X2]], %[[COL2]]] : memref<2x3xf32>
// CHECK-NEXT:      %[[AV:.*]] = memref.load %[[IN_A]][%[[X2]], %[[Y2]]] : memref<2x2xf32>
// CHECK-NEXT:      %[[PROD:.*]] = arith.mulf %[[SV]], %[[AV]] : f32
// CHECK-NEXT:      memref.store %[[PROD]], %[[IN_B]][%[[X2]], %[[Y2]]] : memref<2x2xf32>
// CHECK-NEXT:      scf.reduce
// CHECK-NEXT:    }
// CHECK-NEXT:    memref.dealloc %[[BUF]] : memref<2x3xf32>
// CHECK-NEXT:    return
604
605// -----
606
// Two independent reductions over the same iteration space merge into one
// loop with two results. Each reduction keeps its own combiner region in the
// fused scf.reduce, in the original loop order.
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %seed_add = arith.constant 1.0 : f32
  %seed_mul = arith.constant 2.0 : f32
  %r_add = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_add) -> f32 {
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    scf.reduce(%a_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %s = arith.addf %acc, %elem : f32
      scf.reduce.return %s : f32
    }
  }
  %r_mul = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_mul) -> f32 {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    scf.reduce(%b_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %p = arith.mulf %acc, %elem : f32
      scf.reduce.return %p : f32
    }
  }
  return %r_add, %r_mul : f32, f32
}

// CHECK-LABEL: func @fuse_reductions_two
//  CHECK-SAME:  (%[[IN_A:.*]]: memref<2x2xf32>, %[[IN_B:.*]]: memref<2x2xf32>) -> (f32, f32)
//   CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
//   CHECK-DAG:   %[[TWO:.*]] = arith.constant 2 : index
//   CHECK-DAG:   %[[SEED_ADD:.*]] = arith.constant 1.000000e+00 : f32
//   CHECK-DAG:   %[[SEED_MUL:.*]] = arith.constant 2.000000e+00 : f32
//       CHECK:   %[[FUSED:.*]]:2 = scf.parallel (%[[X:.*]], %[[Y:.*]]) = (%[[ZERO]], %[[ZERO]])
//  CHECK-SAME:   to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]])
//  CHECK-SAME:   init (%[[SEED_ADD]], %[[SEED_MUL]]) -> (f32, f32)
//       CHECK:   %[[AV:.*]] = memref.load %[[IN_A]][%[[X]], %[[Y]]]
//       CHECK:   %[[BV:.*]] = memref.load %[[IN_B]][%[[X]], %[[Y]]]
//       CHECK:   scf.reduce(%[[AV]], %[[BV]] : f32, f32) {
//       CHECK:   ^bb0(%[[ACC0:.*]]: f32, %[[ELEM0:.*]]: f32):
//       CHECK:     %[[ADD0:.*]] = arith.addf %[[ACC0]], %[[ELEM0]] : f32
//       CHECK:     scf.reduce.return %[[ADD0]] : f32
//       CHECK:   }
//       CHECK:   ^bb0(%[[ACC1:.*]]: f32, %[[ELEM1:.*]]: f32):
//       CHECK:     %[[MUL1:.*]] = arith.mulf %[[ACC1]], %[[ELEM1]] : f32
//       CHECK:     scf.reduce.return %[[MUL1]] : f32
//       CHECK:   }
//       CHECK:   return %[[FUSED]]#0, %[[FUSED]]#1 : f32, f32
654
655// -----
656
// Three independent reductions with identical bounds merge into a single loop
// producing three results; the combiner regions appear in the fused
// scf.reduce in the original loop order (add, mul, add).
func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %seed_a = arith.constant 1.0 : f32
  %seed_b = arith.constant 2.0 : f32
  %seed_c = arith.constant 3.0 : f32
  %r_a = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_a) -> f32 {
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    scf.reduce(%a_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %s = arith.addf %acc, %elem : f32
      scf.reduce.return %s : f32
    }
  }
  %r_b = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_b) -> f32 {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    scf.reduce(%b_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %p = arith.mulf %acc, %elem : f32
      scf.reduce.return %p : f32
    }
  }
  %r_c = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_c) -> f32 {
    %c_val = memref.load %C[%x, %y] : memref<2x2xf32>
    scf.reduce(%c_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %s = arith.addf %acc, %elem : f32
      scf.reduce.return %s : f32
    }
  }
  return %r_a, %r_b, %r_c : f32, f32, f32
}

// CHECK-LABEL: func @fuse_reductions_three
//  CHECK-SAME:  (%[[IN_A:.*]]: memref<2x2xf32>, %[[IN_B:.*]]: memref<2x2xf32>, %[[IN_C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
//   CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
//   CHECK-DAG:   %[[TWO:.*]] = arith.constant 2 : index
//   CHECK-DAG:   %[[SEED_A:.*]] = arith.constant 1.000000e+00 : f32
//   CHECK-DAG:   %[[SEED_B:.*]] = arith.constant 2.000000e+00 : f32
//   CHECK-DAG:   %[[SEED_C:.*]] = arith.constant 3.000000e+00 : f32
//       CHECK:   %[[FUSED:.*]]:3 = scf.parallel (%[[X:.*]], %[[Y:.*]]) = (%[[ZERO]], %[[ZERO]])
//  CHECK-SAME:   to (%[[TWO]], %[[TWO]]) step (%[[ONE]], %[[ONE]])
//  CHECK-SAME:   init (%[[SEED_A]], %[[SEED_B]], %[[SEED_C]]) -> (f32, f32, f32)
//       CHECK:   %[[AV:.*]] = memref.load %[[IN_A]][%[[X]], %[[Y]]]
//       CHECK:   %[[BV:.*]] = memref.load %[[IN_B]][%[[X]], %[[Y]]]
//       CHECK:   %[[CV:.*]] = memref.load %[[IN_C]][%[[X]], %[[Y]]]
//       CHECK:   scf.reduce(%[[AV]], %[[BV]], %[[CV]] : f32, f32, f32) {
//       CHECK:   ^bb0(%[[ACC0:.*]]: f32, %[[ELEM0:.*]]: f32):
//       CHECK:     %[[ADD0:.*]] = arith.addf %[[ACC0]], %[[ELEM0]] : f32
//       CHECK:     scf.reduce.return %[[ADD0]] : f32
//       CHECK:   }
//       CHECK:   ^bb0(%[[ACC1:.*]]: f32, %[[ELEM1:.*]]: f32):
//       CHECK:     %[[MUL1:.*]] = arith.mulf %[[ACC1]], %[[ELEM1]] : f32
//       CHECK:     scf.reduce.return %[[MUL1]] : f32
//       CHECK:   }
//       CHECK:   ^bb0(%[[ACC2:.*]]: f32, %[[ELEM2:.*]]: f32):
//       CHECK:     %[[ADD2:.*]] = arith.addf %[[ACC2]], %[[ELEM2]] : f32
//       CHECK:     scf.reduce.return %[[ADD2]] : f32
//       CHECK:   }
//       CHECK:   return %[[FUSED]]#0, %[[FUSED]]#1, %[[FUSED]]#2 : f32, f32, f32
719
720// -----
721
// The first loop's result is the init operand of the second loop, so the
// second depends on the first having finished; the loops are not fused.
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %seed = arith.constant 1.0 : f32
  %r_add = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed) -> f32 {
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    scf.reduce(%a_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %s = arith.addf %acc, %elem : f32
      scf.reduce.return %s : f32
    }
  }
  %r_mul = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%r_add) -> f32 {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    scf.reduce(%b_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %p = arith.mulf %acc, %elem : f32
      scf.reduce.return %p : f32
    }
  }
  return %r_add, %r_mul : f32, f32
}

// CHECK-LABEL: func @reductions_use_res
// CHECK:      scf.parallel
// CHECK:      scf.parallel
750
751// -----
752
// The first loop's result is read inside the second loop's body, so the
// loops are not fused.
func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %seed_add = arith.constant 1.0 : f32
  %seed_mul = arith.constant 2.0 : f32
  %r_add = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_add) -> f32 {
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    scf.reduce(%a_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %s = arith.addf %acc, %elem : f32
      scf.reduce.return %s : f32
    }
  }
  %r_mul = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_mul) -> f32 {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    %shifted = arith.addf %b_val, %r_add : f32
    scf.reduce(%shifted : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %p = arith.mulf %acc, %elem : f32
      scf.reduce.return %p : f32
    }
  }
  return %r_add, %r_mul : f32, f32
}

// CHECK-LABEL: func @reductions_use_res_inside
// CHECK:      scf.parallel
// CHECK:      scf.parallel
783
784// -----
785
// An arith.addf placed between the two loops consumes the first loop's
// result, so the loops are kept separate.
func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
  %two = arith.constant 2 : index
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %seed_add = arith.constant 1.0 : f32
  %seed_mul = arith.constant 2.0 : f32
  %r_add = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_add) -> f32 {
    %a_val = memref.load %A[%x, %y] : memref<2x2xf32>
    scf.reduce(%a_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %s = arith.addf %acc, %elem : f32
      scf.reduce.return %s : f32
    }
  }
  %mid = arith.addf %r_add, %seed_mul : f32
  %r_mul = scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) init(%seed_mul) -> f32 {
    %b_val = memref.load %B[%x, %y] : memref<2x2xf32>
    scf.reduce(%b_val : f32) {
    ^bb0(%acc: f32, %elem: f32):
      %p = arith.mulf %acc, %elem : f32
      scf.reduce.return %p : f32
    }
  }
  return %r_add, %r_mul, %mid : f32, f32, f32
}

// CHECK-LABEL: func @reductions_use_res_between
// CHECK:      scf.parallel
// CHECK:      scf.parallel
816